
I want to merge multiple maps using Spark/Scala. The maps have case class instances as values.

Following is the relevant code:

case class SampleClass(value1:Int,value2:Int)

val sampleDataDs = Seq(
      ("a",25,Map(1->SampleClass(1,2))),
      ("a",32,Map(1->SampleClass(3,4),2->SampleClass(1,2))),
      ("b",16,Map(1->SampleClass(1,2))),
      ("b",18,Map(2->SampleClass(10,15)))).toDF("letter","number","maps")

Output:

+------+-------+--------------------------+
|letter|number |maps                      |
+------+-------+--------------------------+
|a     |  25   | [1-> [1,2]]              |
|a     |  32   | [1-> [3,4], 2 -> [1,2]]  |
|b     |  16   | [1 -> [1,2]]             |
|b     |  18   | [2 -> [10,15]]           |
+------+-------+--------------------------+

I want to group the data on the "letter" column and merge the maps within each group, summing value1 and value2 field-wise for matching keys (for letter "a", key 1: SampleClass(1,2) and SampleClass(3,4) become SampleClass(4,6)). The expected final output is:

+------+---------------------------------+
|letter| maps                            |
+------+---------------------------------+
|a     | [1-> [4,6], 2 -> [1,2]]         |
|b     | [1-> [1,2], 2 -> [10,15]]       |                 
+------+---------------------------------+
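To make the merge rule explicit, this is what it boils down to for the two maps of group "a" in plain Scala (no Spark involved; the val names are only for illustration):

val m1 = Map(1 -> SampleClass(1, 2))
val m2 = Map(1 -> SampleClass(3, 4), 2 -> SampleClass(1, 2))

val merged = (m1.toSeq ++ m2.toSeq)
  .groupBy(_._1)
  .map { case (k, vs) =>
    k -> SampleClass(vs.map(_._2.value1).sum, vs.map(_._2.value2).sum)
  }
// merged: Map(1 -> SampleClass(4,6), 2 -> SampleClass(1,2))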

I tried to group by "letter" and then apply a UDF to aggregate the values in the maps. Below is what I tried:

val aggregatedDs = sampleDataDs.groupBy("letter").agg(collect_list(sampleDataDs("maps")).alias("mapList"))

Output:

+------+----------------------------------------+
|letter| mapList                                |
+------+----------------------------------------+
|a     | [[1-> [1,2]],[1-> [3,4], 2 -> [1,2]]]  |
|b     | [[1-> [1,2]],[2 -> [10,15]]]           |                 
+------+----------------------------------------+ 

After this I tried to write a UDF to merge the output of collect_list and get the final output:

def mergeMap = udf { valSeq: Seq[Map[Int, SampleClass]] =>
  valSeq.flatten
    .groupBy(_._1)
    .mapValues(x => (x.map(_._2.value1).reduce(_ + _), x.map(_._2.value2).reduce(_ + _)))
}

val aggMapDs = aggregatedDs.withColumn("aggValues",mergeMap(col("mapList")))

However, it fails with the following error message:

Failed to execute user defined function
Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to SampleClass

My Spark version is 2.3.1. Any ideas how I can get the expected final output?


2 Answers


The problem is that the UDF cannot accept the case class as input. Spark DataFrames internally represent your case class as a Row object. The problem can thus be avoided by changing the UDF input type as follows:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{col, udf}

val mergeMap = udf((valSeq: Seq[Map[Int, Row]]) => {
  valSeq.flatten
    .groupBy(_._1)
    .mapValues(x =>
      SampleClass(
        x.map(_._2.getAs[Int]("value1")).reduce(_ + _), // read the fields from the Row by name
        x.map(_._2.getAs[Int]("value2")).reduce(_ + _)
      )
    )
})

Notice above that some minor additional changes are necessary to handle the Row object: the fields are read with getAs and the SampleClass is rebuilt from them.
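To see why the UDF receives Row objects, you can inspect the schema of the original dataframe with the standard printSchema call (the output below is approximate):

sampleDataDs.printSchema()
// root
//  |-- letter: string (nullable = true)
//  |-- number: integer (nullable = false)
//  |-- maps: map (nullable = true)
//  |    |-- key: integer
//  |    |-- value: struct (valueContainsNull = true)   <-- the case class is stored as a struct
//  |    |    |-- value1: integer (nullable = false)
//  |    |    |-- value2: integer (nullable = false)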

Running the new mergeMap on the aggregated data results in:

val aggMapDs = aggregatedDs.withColumn("aggValues",mergeMap(col("mapList")))

+------+----------------------------------------------+-----------------------------+
|letter|mapList                                       |aggValues                    |
+------+----------------------------------------------+-----------------------------+
|b     |[Map(1 -> [1,2]), Map(2 -> [10,15])]          |Map(2 -> [10,15], 1 -> [1,2])|
|a     |[Map(1 -> [1,2]), Map(1 -> [3,4], 2 -> [1,2])]|Map(2 -> [1,2], 1 -> [4,6])  |
+------+----------------------------------------------+-----------------------------+
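If you only need the letter and the merged map, as in the expected final output, you can select those two columns afterwards; a minimal sketch:

val finalDs = aggMapDs.select(col("letter"), col("aggValues").alias("maps"))
finalDs.show(false)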
Shaido

There is a slight difference between DataFrame and Dataset.

A Dataset takes on two distinct API characteristics: a strongly-typed API and an untyped API. Conceptually, consider DataFrame as an alias for a collection of generic objects, Dataset[Row], where a Row is a generic untyped JVM object. Dataset, by contrast, is a collection of strongly-typed JVM objects, dictated by a case class you define in Scala or a class in Java.

When you convert your Seq to a DataFrame, the type information is lost.

val df: DataFrame = Seq(...).toDF() // <-- here

What you could have done instead is convert the Seq to a Dataset:

val typedDs: Dataset[(String, Int, Map[Int, SampleClass])] = Seq(...).toDS()

+---+---+--------------------+
| _1| _2|                  _3|
+---+---+--------------------+
|  a| 25|       [1 -> [1, 2]]|
|  a| 32|[1 -> [3, 4], 2 -...|
|  b| 16|       [1 -> [1, 2]]|
|  b| 18|     [2 -> [10, 15]]|
+---+---+--------------------+

Because the top-level object in the Seq is a Tuple, Spark generates dummy column names (_1, _2, _3).
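If you want to keep the original column names in the typed API, one option is to wrap each row in a case class instead of a tuple (a sketch; the name SampleRow is just for illustration):

case class SampleRow(letter: String, number: Int, maps: Map[Int, SampleClass])

val namedDs: Dataset[SampleRow] = Seq(
  SampleRow("a", 25, Map(1 -> SampleClass(1, 2))),
  SampleRow("a", 32, Map(1 -> SampleClass(3, 4), 2 -> SampleClass(1, 2))),
  SampleRow("b", 16, Map(1 -> SampleClass(1, 2))),
  SampleRow("b", 18, Map(2 -> SampleClass(10, 15)))
).toDS()

The groupBy/collect_list steps below work the same way on namedDs, just with "letter" and "maps" instead of "_1" and "_3".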

Now you should pay attention to the return types: there are functions on a typed Dataset that lose the type information, for example:

val untyped: DataFrame = typedDs
  .groupBy("_1")
  .agg(collect_list(typedDs("_3")).alias("mapList"))

In order to work with the typed API you have to define the types explicitly:

val aggregatedDs = sampleDataDs
      .groupBy("letter")
      .agg(collect_list(sampleDataDs("maps")).alias("mapList"))

val toTypedAgg: Dataset[(String, Array[Map[Int, SampleClass]])] = aggregatedDs
 .as[(String, Array[Map[Int, SampleClass]])] //<- here

Unfortunately, a udf won't work here, since Spark can infer a schema only for a limited number of types:

toTypedAgg.withColumn("aggValues", mergeMap1(col("mapList"))).show()

Schema for type org.apache.spark.sql.Dataset[(String, Array[Map[Int,SampleClass]])] is not supported

What you can do instead is map over the Dataset:

val mapped = toTypedAgg.map(v => {
  (v._1, v._2.flatten
    .groupBy(_._1)
    .mapValues(x => (x.map(_._2.value1).sum, x.map(_._2.value2).sum)))
})

+---+----------------------------+
|_1 |_2                          |
+---+----------------------------+
|b  |[2 -> [10, 15], 1 -> [1, 2]]|
|a  |[2 -> [1, 2], 1 -> [4, 6]]  |
+---+----------------------------+
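Another option, not part of the original answer but a sketch of the same idea, is to stay in the typed API end to end with groupByKey/mapGroups, which avoids both collect_list and the UDF. It reuses the typedDs defined above; mergedTyped is an illustrative name:

val mergedTyped: Dataset[(String, Map[Int, SampleClass])] = typedDs
  .groupByKey(_._1)                       // group by the letter
  .mapGroups { case (letter, rows) =>
    val merged = rows.flatMap(_._3).toSeq // all (key, SampleClass) entries of the group
      .groupBy(_._1)
      .map { case (k, vs) =>
        k -> SampleClass(vs.map(_._2.value1).sum, vs.map(_._2.value2).sum)
      }
    (letter, merged)
  }

This keeps the SampleClass type in the result, so for letter "a" you get Map(1 -> SampleClass(4,6), 2 -> SampleClass(1,2)).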
Gelerion