For a simple problem like this, you could also use the explode function. I haven't compared its performance against the accepted udf-based answer, though.
from pyspark.sql import functions as F
df = sc.parallelize([(1, [1, 2, 3]), (1, [4, 5, 6]), (2, [2]), (2, [3])]).toDF(['store', 'values'])
df2 = df.withColumn('values', F.explode('values'))
# +-----+------+
# |store|values|
# +-----+------+
# |    1|     1|
# |    1|     2|
# |    1|     3|
# |    1|     4|
# |    1|     5|
# |    1|     6|
# |    2|     2|
# |    2|     3|
# +-----+------+
df3 = df2.groupBy('store').agg(F.collect_list('values').alias('values'))
# +-----+------------------+
# |store|values            |
# +-----+------------------+
# |1    |[4, 5, 6, 1, 2, 3]|
# |2    |[2, 3]            |
# +-----+------------------+
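If you do want to gauge that performance question, comparing the physical plans is a cheap first check. A minimal sketch, where udf_df is just a placeholder for the result DataFrame from the udf-based answer (it isn't defined here):
df3.explain()       # plan for the explode + collect_list approach
# udf_df.explain()  # plan for the udf-based approach, for comparison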
Note: you could use F.collect_set() in the aggregation or .drop_duplicates() on df2 to remove duplicate values.
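A quick sketch of both options, reusing df2 from above (note that collect_set does not guarantee any ordering):
df_set = df2.groupBy('store').agg(F.collect_set('values').alias('values'))
df_dedup = df2.drop_duplicates().groupBy('store').agg(F.collect_list('values').alias('values'))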
If you want the collected list to preserve the order of the values, I found the following method in another SO answer:
from pyspark.sql.window import Window
w = Window.partitionBy('store').orderBy('values')
df3 = df2.withColumn('ordered_value_lists', F.collect_list('values').over(w))
# +-----+------+-------------------+
# |store|values|ordered_value_lists|
# +-----+------+-------------------+
# |1    |1     |[1]                |
# |1    |2     |[1, 2]             |
# |1    |3     |[1, 2, 3]          |
# |1    |4     |[1, 2, 3, 4]       |
# |1    |5     |[1, 2, 3, 4, 5]    |
# |1    |6     |[1, 2, 3, 4, 5, 6] |
# |2    |2     |[2]                |
# |2    |3     |[2, 3]             |
# +-----+------+-------------------+
# Within each store, every row's list is a prefix of the next, so F.max returns the longest (complete) list.
df4 = df3.groupBy('store').agg(F.max('ordered_value_lists').alias('values'))
df4.show(truncate=False)
# +-----+------------------+
# |store|values            |
# +-----+------------------+
# |1    |[1, 2, 3, 4, 5, 6]|
# |2    |[2, 3]            |
# +-----+------------------+
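As an aside, when the desired order is simply the sorted order of the values (as it is here, since the window orders by 'values'), you can skip the window and sort the collected list directly. A minimal sketch:
df3_sorted = df2.groupBy('store').agg(F.sort_array(F.collect_list('values')).alias('values'))
# Gives [1, 2, 3, 4, 5, 6] for store 1 and [2, 3] for store 2, same as df4 above.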
If the values themselves don't determine the order, you can use F.posexplode() and order the window by the 'pos' column instead of 'values'. Note: you will also need a higher-level order column to order the original arrays, and then use the position within each array to order its elements.
df = sc.parallelize([(1, [1, 2, 3], 1), (1, [4, 5, 6], 2), (2, [2], 1), (2, [3], 2)]).toDF(['store', 'values', 'array_order'])
# +-----+---------+-----------+
# |store|values   |array_order|
# +-----+---------+-----------+
# |1    |[1, 2, 3]|1          |
# |1    |[4, 5, 6]|2          |
# |2    |[2]      |1          |
# |2    |[3]      |2          |
# +-----+---------+-----------+
df2 = df.select('*', F.posexplode('values'))
# +-----+---------+-----------+---+---+
# |store|values   |array_order|pos|col|
# +-----+---------+-----------+---+---+
# |1    |[1, 2, 3]|1          |0  |1  |
# |1    |[1, 2, 3]|1          |1  |2  |
# |1    |[1, 2, 3]|1          |2  |3  |
# |1    |[4, 5, 6]|2          |0  |4  |
# |1    |[4, 5, 6]|2          |1  |5  |
# |1    |[4, 5, 6]|2          |2  |6  |
# |2    |[2]      |1          |0  |2  |
# |2    |[3]      |2          |0  |3  |
# +-----+---------+-----------+---+---+
w = Window.partitionBy('store').orderBy('array_order', 'pos')
df3 = df2.withColumn('ordered_value_lists', F.collect_list('col').over(w))
# +-----+---------+-----------+---+---+-------------------+
# |store|values   |array_order|pos|col|ordered_value_lists|
# +-----+---------+-----------+---+---+-------------------+
# |1    |[1, 2, 3]|1          |0  |1  |[1]                |
# |1    |[1, 2, 3]|1          |1  |2  |[1, 2]             |
# |1    |[1, 2, 3]|1          |2  |3  |[1, 2, 3]          |
# |1    |[4, 5, 6]|2          |0  |4  |[1, 2, 3, 4]       |
# |1    |[4, 5, 6]|2          |1  |5  |[1, 2, 3, 4, 5]    |
# |1    |[4, 5, 6]|2          |2  |6  |[1, 2, 3, 4, 5, 6] |
# |2    |[2]      |1          |0  |2  |[2]                |
# |2    |[3]      |2          |0  |3  |[2, 3]             |
# +-----+---------+-----------+---+---+-------------------+
df4 = df3.groupBy('store').agg(F.max('ordered_value_lists').alias('values'))
# +-----+------------------+
# |store|values            |
# +-----+------------------+
# |1    |[1, 2, 3, 4, 5, 6]|
# |2    |[2, 3]            |
# +-----+------------------+
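If you're on Spark 2.4+, you can also get this result without exploding at all: collect the arrays together with their order column as structs, sort, then extract and flatten. This is a sketch of an alternative, not part of the original answer:
df_alt = (
    df.groupBy('store')
      .agg(F.sort_array(F.collect_list(F.struct('array_order', 'values'))).alias('s'))
      .select('store', F.flatten(F.col('s.values')).alias('values'))
)
# s is an array of (array_order, values) structs; sort_array orders it by array_order,
# 's.values' pulls out the inner arrays, and flatten concatenates them in that order.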
Edit: If you'd like to keep some columns along for the ride and they don't need to be aggregated, you can include them in the groupBy or rejoin them after aggregation (examples below). If they do require aggregation, group by 'store' only and add whatever aggregation you need for the 'other' column(s) to the .agg() call (see the sketch at the end).
from pyspark.sql import functions as F
df = sc.parallelize([(1, [1, 2, 3], 'a'), (1, [4, 5, 6], 'a'), (2, [2], 'b'), (2, [3], 'b')]).toDF(['store', 'values', 'other'])
# +-----+---------+-----+
# |store|   values|other|
# +-----+---------+-----+
# |    1|[1, 2, 3]|    a|
# |    1|[4, 5, 6]|    a|
# |    2|      [2]|    b|
# |    2|      [3]|    b|
# +-----+---------+-----+
df2 = df.withColumn('values', F.explode('values'))
# +-----+------+-----+
# |store|values|other|
# +-----+------+-----+
# |    1|     1|    a|
# |    1|     2|    a|
# |    1|     3|    a|
# |    1|     4|    a|
# |    1|     5|    a|
# |    1|     6|    a|
# |    2|     2|    b|
# |    2|     3|    b|
# +-----+------+-----+
df3 = df2.groupBy('store', 'other').agg(F.collect_list('values').alias('values'))
# +-----+-----+------------------+
# |store|other|            values|
# +-----+-----+------------------+
# |    1|    a|[1, 2, 3, 4, 5, 6]|
# |    2|    b|            [2, 3]|
# +-----+-----+------------------+
df4 = (
    df.drop('values')
      .join(
          df2.groupBy('store').agg(F.collect_list('values').alias('values')),
          on=['store'], how='inner'
      )
      .drop_duplicates()
)
# +-----+-----+------------------+
# |store|other|            values|
# +-----+-----+------------------+
# |    1|    a|[1, 2, 3, 4, 5, 6]|
# |    2|    b|            [2, 3]|
# +-----+-----+------------------+
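And if the extra column does need aggregating, keep the groupBy on 'store' only and add the aggregation you want next to the collect_list. A minimal sketch using F.first, which assumes 'other' is constant within a store:
df5 = df2.groupBy('store').agg(
    F.collect_list('values').alias('values'),
    F.first('other').alias('other')  # or F.max, F.collect_set, etc., depending on what you need
)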