I want to merge multiple maps using Spark/Scala. The maps have a case class instance as value.
Following is the relevant code:
case class SampleClass(value1: Int, value2: Int)
val sampleDataDs = Seq(
  ("a", 25, Map(1 -> SampleClass(1, 2))),
  ("a", 32, Map(1 -> SampleClass(3, 4), 2 -> SampleClass(1, 2))),
  ("b", 16, Map(1 -> SampleClass(1, 2))),
  ("b", 18, Map(2 -> SampleClass(10, 15)))
).toDF("letter", "number", "maps")
Output:
+------+------+------------------------+
|letter|number|maps                    |
+------+------+------------------------+
|a     |25    |[1 -> [1,2]]            |
|a     |32    |[1 -> [3,4], 2 -> [1,2]]|
|b     |16    |[1 -> [1,2]]            |
|b     |18    |[2 -> [10,15]]          |
+------+------+------------------------+
I want to group the data by the "letter" column and merge the maps within each group, summing value1 and value2 for entries that share a key. The expected final output is:
+------+--------------------------+
|letter|maps                      |
+------+--------------------------+
|a     |[1 -> [4,6], 2 -> [1,2]]  |
|b     |[1 -> [1,2], 2 -> [10,15]]|
+------+--------------------------+
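To make the intended merge concrete, here is a minimal sketch of the same aggregation on plain Scala collections, without Spark. The helper name mergeMaps and the value names mergedA and mergedB are only for illustration; the input maps are the ones from the sample rows above.

```scala
case class SampleClass(value1: Int, value2: Int)

// Merge maps key by key, summing value1 and value2 of entries that share a key.
def mergeMaps(maps: Seq[Map[Int, SampleClass]]): Map[Int, SampleClass] =
  maps.flatten
    .groupBy { case (key, _) => key }
    .map { case (key, entries) =>
      val values = entries.map { case (_, v) => v }
      key -> SampleClass(values.map(_.value1).sum, values.map(_.value2).sum)
    }

// The two maps from the rows with letter "a"
val mergedA = mergeMaps(Seq(
  Map(1 -> SampleClass(1, 2)),
  Map(1 -> SampleClass(3, 4), 2 -> SampleClass(1, 2))
))

// The two maps from the rows with letter "b"
val mergedB = mergeMaps(Seq(
  Map(1 -> SampleClass(1, 2)),
  Map(2 -> SampleClass(10, 15))
))
```

This is the result I am after per group; the difficulty is getting the same thing to run inside Spark.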
I tried to group by "letter" and then apply a UDF to aggregate the values in the maps. First I collected the maps per group:
val aggregatedDs = sampleDataDs.groupBy("letter").agg(collect_list(sampleDataDs("maps")).alias("mapList"))
Output:
+------+----------------------------------------+
|letter|mapList                                 |
+------+----------------------------------------+
|a     |[[1 -> [1,2]], [1 -> [3,4], 2 -> [1,2]]]|
|b     |[[1 -> [1,2]], [2 -> [10,15]]]          |
+------+----------------------------------------+
After this I tried to write a UDF that merges the output of collect_list and produces the final output:
def mergeMap = udf { valSeq: Seq[Map[Int, SampleClass]] =>
  valSeq.flatten
    .groupBy(_._1)
    .mapValues(x => (x.map(_._2.value1).reduce(_ + _), x.map(_._2.value2).reduce(_ + _)))
}
val aggMapDs = aggregatedDs.withColumn("aggValues", mergeMap(col("mapList")))
However, it fails with the following error message:
Failed to execute user defined function
Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to SampleClass
My Spark version is 2.3.1. Any ideas how I can get the expected final output?
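One direction that might work, sketched under assumptions: as far as I understand, Spark hands struct values to a Scala UDF as Row objects rather than as the original case class, which would be consistent with the ClassCastException above. The sketch below therefore types the UDF input as Seq[Map[Int, Row]] and reads value1 and value2 by field name. The names mergeMapsUdf and merged, the local SparkSession setup, and the app name are hypothetical choices for illustration, not something from my actual job.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, udf}

case class SampleClass(value1: Int, value2: Int)

// Assumption: a local session for experimenting; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("merge-maps-sketch").getOrCreate()
import spark.implicits._

val sampleDataDs = Seq(
  ("a", 25, Map(1 -> SampleClass(1, 2))),
  ("a", 32, Map(1 -> SampleClass(3, 4), 2 -> SampleClass(1, 2))),
  ("b", 16, Map(1 -> SampleClass(1, 2))),
  ("b", 18, Map(2 -> SampleClass(10, 15)))
).toDF("letter", "number", "maps")

// The struct values are declared as Row here, not SampleClass,
// and the fields are read by name before summing per key.
val mergeMapsUdf = udf { maps: Seq[Map[Int, Row]] =>
  maps.flatten
    .groupBy { case (key, _) => key }
    .map { case (key, entries) =>
      val rows = entries.map { case (_, row) => row }
      key -> SampleClass(
        rows.map(_.getAs[Int]("value1")).sum,
        rows.map(_.getAs[Int]("value2")).sum
      )
    }
}

val merged = sampleDataDs
  .groupBy("letter")
  .agg(collect_list(col("maps")).alias("mapList"))
  .withColumn("maps", mergeMapsUdf(col("mapList")))
  .drop("mapList")
```

Returning SampleClass from the UDF is meant to keep the value1 and value2 field names in the resulting struct. The sketch uses map rather than mapValues so that the UDF returns a strict Map instead of a lazy view.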