I have a Spark Dataset in the following format:

+--------------+--------+-----+
|name          |type    |cost |
+--------------+--------+-----+
|AAAAAAAAAAAAAA|XXXXX   |0.24 |
|AAAAAAAAAAAAAA|YYYYY   |1.14 |
|BBBBBBBBBBBBBB|XXXXX   |0.78 |
|BBBBBBBBBBBBBB|YYYYY   |2.67 |
|BBBBBBBBBBBBBB|ZZZZZ   |0.15 |
|CCCCCCCCCCCCCC|XXXXX   |1.86 |
|CCCCCCCCCCCCCC|YYYYY   |1.50 |
|CCCCCCCCCCCCCC|ZZZZZ   |1.00 |
+--------------+--------+-----+

I want to transform this into objects of the type:

public class CostPerName {
    private String name;
    private Map<String, Double> costTypeMap;
}

What I want is,

+--------------+-----------------------------------------------+
|name          |typeCost                                       |
+--------------+-----------------------------------------------+
|AAAAAAAAAAAAAA|(XXXXX, 0.24), (YYYYY, 1.14)                   |
|BBBBBBBBBBBBBB|(XXXXX, 0.78), (YYYYY, 2.67), (ZZZZZ, 0.15)    |
|CCCCCCCCCCCCCC|(XXXXX, 1.86), (YYYYY, 1.50), (ZZZZZ, 1.00)    |
+--------------+-----------------------------------------------+

i.e., for each name, I want a map of (type, cost) pairs.

What is an efficient way to achieve this transformation? Can I use some DataFrame transformation? I tried groupBy, but that only seems to work when I am performing aggregations like sum, avg, etc.

2 Answers


You can combine the two columns type and cost into a new struct column, then group by name and use collect_list as the aggregation function:

df.withColumn("type_cost", struct("type", "cost"))
     .groupBy("name").agg(collect_list("type_cost"))

This will result in a dataframe like this:

+--------------+---------------------------------------------+
|name          |collect_list(type_cost)                      |
+--------------+---------------------------------------------+
|AAAAAAAAAAAAAA|[[XXXXX, 0.24], [YYYYY, 1.14]]               |
|CCCCCCCCCCCCCC|[[XXXXX, 1.86], [YYYYY, 1.50], [ZZZZZ, 1.00]]|
|BBBBBBBBBBBBBB|[[XXXXX, 0.78], [YYYYY, 2.67], [ZZZZZ, 0.15]]|
+--------------+---------------------------------------------+
werner
  • Thanks, this works perfectly. So, does the answer by @mazaneicha. Just so it helps me and other folks who stumble here, could you tell me how you approach understanding spark sql in depth? I am asking because I did go through the spark docs but could not think of this approach. – maddie Jun 10 '20 at 19:37
  • please note that using this solution `typeCost` becomes a list not a map. Other than that, great answer! – mazaneicha Jun 10 '20 at 19:39
  • @mazaneicha you are right. For Spark versions >= 2.4 your answer is closer to the question – werner Jun 10 '20 at 19:44
  • @maddie try to run a printSchema after the first step. Then you can see that the new column type_cost is a struct. This struct is then collected within the aggregation – werner Jun 10 '20 at 20:05
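As pointed out in the comments, collect_list of structs yields an array rather than the Map the question asks for. A minimal follow-up sketch (not part of the original answer; it assumes Spark 2.4+, where map_from_entries is available) that converts the collected structs into a map:

import org.apache.spark.sql.functions.{collect_list, map_from_entries, struct}

df.withColumn("type_cost", struct("type", "cost"))
  .groupBy("name")
  .agg(map_from_entries(collect_list("type_cost")).as("typeCost"))  // array<struct<type,cost>> -> map keyed by type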

You can use map_from_arrays() if your Spark version (2.4+) supports it:

scala> val df2 = df.groupBy("name").agg(map_from_arrays(collect_list($"type"), collect_list($"cost")).as("typeCost"))
df2: org.apache.spark.sql.DataFrame = [name: string, typeCost: map<string,decimal(3,2)>]

scala> df2.printSchema()
root
 |-- name: string (nullable = false)
 |-- typeCost: map (nullable = true)
 |    |-- key: string
 |    |-- value: decimal(3,2) (valueContainsNull = true)

scala> df2.show(false)
+--------------+---------------------------------------------+
|name          |typeCost                                     |
+--------------+---------------------------------------------+
|AAAAAAAAAAAAAA|[XXXXX -> 0.24, YYYYY -> 1.14]               |
|CCCCCCCCCCCCCC|[XXXXX -> 1.86, YYYYY -> 1.50, ZZZZZ -> 1.00]|
|BBBBBBBBBBBBBB|[XXXXX -> 0.78, YYYYY -> 2.67, ZZZZZ -> 0.15]|
+--------------+---------------------------------------------+

mazaneicha