45

I was wondering if there is some way to specify a custom aggregation function for Spark DataFrames over multiple columns.

I have a table like this of the type (name, item, price):

john | tomato | 1.99
john | carrot | 0.45
bill | apple  | 0.99
john | banana | 1.29
bill | taco   | 2.59

I would like to aggregate each item and its cost into a list per person, like this:

john | (tomato, 1.99), (carrot, 0.45), (banana, 1.29)
bill | (apple, 0.99), (taco, 2.59)

Is this possible with DataFrames? I recently learned about collect_list, but it appears to only work for one column.

zero323
anthonybell

5 Answers

112

Consider using the struct function to group the columns together before collecting as a list:

import org.apache.spark.sql.functions.{collect_list, struct}
import sqlContext.implicits._

val df = Seq(
  ("john", "tomato", 1.99),
  ("john", "carrot", 0.45),
  ("bill", "apple", 0.99),
  ("john", "banana", 1.29),
  ("bill", "taco", 2.59)
).toDF("name", "food", "price")

df.groupBy($"name")
  .agg(collect_list(struct($"food", $"price")).as("foods"))
  .show(false)

Outputs:

+----+---------------------------------------------+
|name|foods                                        |
+----+---------------------------------------------+
|john|[[tomato,1.99], [carrot,0.45], [banana,1.29]]|
|bill|[[apple,0.99], [taco,2.59]]                  |
+----+---------------------------------------------+
Felix Leipold
Daniel Siegmann
38

The easiest way to do this as a DataFrame is to first collect two lists, and then use a UDF to zip the two lists together. Something like:

import org.apache.spark.sql.functions.{col, collect_list, udf}
import sqlContext.implicits._

val zipper = udf[Seq[(String, Double)], Seq[String], Seq[Double]](_.zip(_))

val df = Seq(
  ("john", "tomato", 1.99),
  ("john", "carrot", 0.45),
  ("bill", "apple", 0.99),
  ("john", "banana", 1.29),
  ("bill", "taco", 2.59)
).toDF("name", "food", "price")

val df2 = df.groupBy("name").agg(
  collect_list(col("food")) as "food",
  collect_list(col("price")) as "price" 
).withColumn("food", zipper(col("food"), col("price"))).drop("price")

df2.show(false)
// +----+---------------------------------------------+
// |name|food                                         |
// +----+---------------------------------------------+
// |john|[[tomato,1.99], [carrot,0.45], [banana,1.29]]|
// |bill|[[apple,0.99], [taco,2.59]]                  |
// +----+---------------------------------------------+
David Griffin
  • 1
    I used `col(...)` instead of `$"..."` for a reason -- I find `col(...)` works with less work inside of things like `class` definitions. – David Griffin Jun 10 '16 at 11:58
  • Is there any function to realign columns like for example in the zip function tell it to first add an element from the tail of the column and remove one from the head and then zip them? In this case you can have for example next price for the items if you read prices daily and there is a time column. – M.Rez Sep 01 '16 at 14:24
  • Not entirely sure what you are asking. But you can use `DataFrame.select(...)` to change the order of columns. – David Griffin Sep 01 '16 at 15:22
  • I meant something like this question: http://stackoverflow.com/q/39274585/2525128. I have used this answer a lot in my code, but I am trying to use the same method for time series data and add the next occurrence of an event for a specific observation as a field. Since I do not know exactly when that occurrence happens, it is a little hard to do. – M.Rez Sep 01 '16 at 16:24
  • 13
    The answer assumes (maybe correctly) that collect_list() will preserve the order of elements on the two columns food & price. Meaning that food and price from the same row will end up at the same index in the two collected lists. Is this order preserving behavior guaranteed? (it would make sense, but I'm not sure by looking at the scala code for collect_list, not a scala programmer). – Kai Jan 11 '17 at 14:21
  • 3
    Afaik, there is no guarantee that the order of elements will be the same. cf : https://stackoverflow.com/questions/40407514/use-more-than-one-collect-list-in-one-query-in-spark-sql – Yann Moisan Oct 12 '17 at 11:44
  • 1
    I used a variation of this solution to zip five lists together. This gave me the opportunity to write the best line of code of my career so far: _ zip _ zip _ zip _ zip _ – Jeremy Apr 20 '18 at 20:39
  • Note: The function is non-deterministic because the order of collected results depends on order of rows which may be non-deterministic after a shuffle. https://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=dataframe#pyspark.sql.functions.collect_list – rjurney Nov 27 '18 at 01:14
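The ordering concern raised in these comments can be sidestepped by pairing the values before collecting. The following is a minimal sketch, not part of the original answer: it assumes Spark 2.x or later with a local `SparkSession`, and the names `paired` and the app name are illustrative. Because each struct is built from a single row, food and price cannot end up at different indices; wrapping the result in `sort_array` is an optional extra step that makes the list order independent of row arrival order.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, sort_array, struct}

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("paired-collect").getOrCreate()
import spark.implicits._

val df = Seq(
  ("john", "tomato", 1.99),
  ("john", "carrot", 0.45),
  ("bill", "apple", 0.99),
  ("john", "banana", 1.29),
  ("bill", "taco", 2.59)
).toDF("name", "food", "price")

// Build the (food, price) pair per row first, then collect the pairs.
// sort_array orders the structs field by field (food, then price),
// so the resulting list does not depend on the order of rows after the shuffle.
val paired = df.groupBy($"name")
  .agg(sort_array(collect_list(struct($"food", $"price"))).as("foods"))

paired.show(false)
```

If the original row order matters rather than alphabetical order, one option is to put an explicit ordering column (for example a timestamp) as the first field of the struct, so that `sort_array` sorts on it.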
10

A better way than the zip UDF may be to wrap the two columns into a struct, since UDFs and UDAFs tend to be costly for performance.

This would probably work as well:

df.select('name, struct('food, 'price).as("tuple"))
  .groupBy('name)
  .agg(collect_list('tuple).as("tuples"))
lzagkaretos
Yifan Guo
6

To your point that collect_list appears to only work for one column: for collect_list to work on multiple columns, you have to wrap the columns you want to aggregate in a struct. For example:

import org.apache.spark.sql.functions.{collect_list, struct}

// The alias matches the "foods" column shown in the output below.
val aggregatedData = df.groupBy("name").agg(collect_list(struct("item", "price")).as("foods"))

aggregatedData.show(false)
+----+------------------------------------------------+
|name|foods                                           |
+----+------------------------------------------------+
|john|[[tomato, 1.99], [carrot, 0.45], [banana, 1.29]]|
|bill|[[apple, 0.99], [taco, 2.59]]                   |
+----+------------------------------------------------+
Neha Kumari
2

Here is an option: convert the DataFrame to an RDD of key-value pairs and then call groupByKey on it. The result is a list of key-value pairs where each value is a list of tuples.

df.show
+----+------+----+
|  _1|    _2|  _3|
+----+------+----+
|john|tomato|1.99|
|john|carrot|0.45|
|bill| apple|0.99|
|john|banana|1.29|
|bill|  taco|2.59|
+----+------+----+


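// df.rdd works on Spark 1.3+ and 2.x; on 2.x a plain df.map returns a Dataset and needs an encoder.
val tuples = df.rdd.map(row => row(0) -> (row(1), row(2)))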
tuples: org.apache.spark.rdd.RDD[(Any, (Any, Any))] = MapPartitionsRDD[102] at map at <console>:43

tuples.groupByKey().map{ case(x, y) => (x, y.toList) }.collect
res76: Array[(Any, List[(Any, Any)])] = Array((bill,List((apple,0.99), (taco,2.59))), (john,List((tomato,1.99), (carrot,0.45), (banana,1.29))))
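For readers on Spark 2.x or later, a typed variant of the same RDD approach might look like the sketch below. It is an illustration rather than the answer's own code: it assumes columns typed (String, String, Double), a local `SparkSession`, and the hypothetical names `pairs` and `grouped`. Using the typed getters on `Row` avoids the `Any` element types seen in the result above.

```scala
import org.apache.spark.sql.SparkSession

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[*]").appName("rdd-group").getOrCreate()
import spark.implicits._

val df = Seq(
  ("john", "tomato", 1.99),
  ("john", "carrot", 0.45),
  ("bill", "apple", 0.99),
  ("john", "banana", 1.29),
  ("bill", "taco", 2.59)
).toDF("name", "food", "price")

// df.rdd yields an RDD[Row]; the typed getters give (String, (String, Double)) pairs.
val pairs = df.rdd.map(row => row.getString(0) -> (row.getString(1), row.getDouble(2)))

// Group by name and materialize each group as a List on the driver.
val grouped: Map[String, List[(String, Double)]] =
  pairs.groupByKey().mapValues(_.toList).collect().toMap

grouped.foreach(println)
```

The order of elements inside each list is not guaranteed after the shuffle, so any comparison of results is safer on sets or sorted lists.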
Psidom