Someone has already asked about computing a weighted average in Spark, but in this question I'm asking about doing it with Datasets/DataFrames instead of RDDs.
How do I compute a weighted average in Spark? I have two columns: counts and previous averages:
case class Stat(name: String, count: Int, average: Double)

val statset = spark.createDataset(Seq(
  Stat("NY", 1, 5.0),
  Stat("NY", 2, 1.5),
  Stat("LA", 12, 1.0),
  Stat("LA", 15, 3.0)))
I would like to be able to compute a weighted average like this:
display(statset.groupBy($"name").agg(
  sum($"count").as("count"),
  weightedAverage($"count", $"average").as("average")))
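In other words, for each name the result should be the count-weighted mean of the averages. For the sample data above, that works out to (if I've done the arithmetic right):

NY: (1*5.0 + 2*1.5) / (1 + 2)     = 8.0 / 3    ≈ 2.67
LA: (12*1.0 + 15*3.0) / (12 + 15) = 57.0 / 27  ≈ 2.11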
One can use a UDF to get close:
val weightedAverage = udf((row: Row) => {
  // The struct holds two parallel arrays: the collected counts and the collected averages.
  val counts = row.getAs[WrappedArray[Int]](0)
  val averages = row.getAs[WrappedArray[Double]](1)
  // Accumulate the total count and the count-weighted sum in a single pass.
  val (count, total) = (counts zip averages).foldLeft((0, 0.0)) {
    case ((cumCount, cumTotal), (newCount, newAverage)) =>
      (cumCount + newCount, cumTotal + newCount * newAverage)
  }
  total / count // Tested by returning count here and then extracting; got the same result as sum.
})
display(statset.groupBy($"name").agg(
  sum($"count").as("count"),
  weightedAverage(struct(collect_list($"count"),
                         collect_list($"average"))).as("average")))
(Thanks to the answers to "Passing a list of tuples as a parameter to a spark udf in scala" for help in writing this.)
If you're new to this, these are the imports you'll need:
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection.mutable.WrappedArray
Is there a way of accomplishing this with built-in column functions instead of UDFs? The UDF feels clunky, and if the counts get large you have to convert the Ints to Longs.
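In case it helps clarify what I'm after: I imagine an expression-only version would look roughly like the sketch below (just column arithmetic inside agg), but I haven't tested it and I don't know whether this is the idiomatic way, or whether there's a dedicated built-in for weighted averages.

// Untested sketch of what I'm hoping for: built-in column functions only, no UDF.
// (I believe sum over an Int column already yields a Long, which would also sidestep the overflow concern.)
display(statset.groupBy($"name").agg(
  sum($"count").as("count"),
  (sum($"count" * $"average") / sum($"count")).as("average")))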