
I have a dataset like the one below, with values of col1 repeating multiple times and unique values of col2. The original dataset can have almost a billion rows, so I do not want to use collect or collect_list, as they will not scale out for my use case.

Original Dataset:

+----+----+
|col1|col2|
+----+----+
|  AA|  11|
|  BB|  21|
|  AA|  12|
|  AA|  13|
|  BB|  22|
|  CC|  33|
+----+----+

I want to transform the dataset into the following array format, with newColumn as an array of the col2 values for each col1.

Transformed Dataset:

+----+----------+
|col1| newColumn|
+----+----------+
|  AA|[11,12,13]|
|  BB|   [21,22]|
|  CC|      [33]|
+----+----------+

I have seen this solution, but it uses collect_list and will not scale out on big datasets.

Ani
  • I upvoted the question because this is more about performance than about the procedure itself. I wonder why this was closed. – Raghu Jun 24 '20 at 16:48
  • Yeah, same here. It was closed automatically; maybe they didn't read the whole question and missed the performance aspect. – Ani Jun 24 '20 at 19:43

2 Answers

  1. Load your dataframe
  2. Group by col1
  3. Aggregate col2 to a list using collect_list
import org.apache.spark.sql.{SparkSession, functions}

object GroupToArray {

  def main(args: Array[String]): Unit = {

    // Create a SparkSession
    val spark = SparkSession.builder()
      .appName("GroupToArray")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    // Load your dataframe
    val df = List(("AA", "11"),
      ("BB", "21"),
      ("AA", "12"),
      ("AA", "13"),
      ("BB", "22"),
      ("CC", "33")).toDF("col1", "col2")

    // Group by 'col1' and aggregate col2 into a list
    df.groupBy("col1")
      .agg(functions.collect_list("col2").as("newColumn"))
      .show()
  }

}
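On the sample data above, show() prints something like the following (the order of the rows and of the elements inside each array is not guaranteed):

+----+------------+
|col1|   newColumn|
+----+------------+
|  AA|[11, 12, 13]|
|  BB|    [21, 22]|
|  CC|        [33]|
+----+------------+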
QuickSilver

Using the built-in functions of Spark is usually the best way, and I see no problem in using the collect_list function here. As long as you have sufficient memory, this would be the best approach. One way of optimizing your job would be to save your data as Parquet, bucket it by col1, and save it as a table. Even better would be to also partition it by some column that distributes the data evenly.

For example,

from pyspark.sql import functions as F
from pyspark.sql.functions import col

df_stored = ...  # load your data from CSV, Parquet or any other format
spark.catalog.setCurrentDatabase(database_name)
df_stored.write.mode("overwrite").format("parquet") \
    .partitionBy(part_col).bucketBy(10, "col1") \
    .option("path", savepath).saveAsTable(tablename)
df_analysis = spark.table(tablename)
df_aggreg = df_analysis.groupby("col1").agg(F.collect_list(col("col2")).alias("newColumn"))

This would speed up the aggregation and avoid a lot of shuffle. Try it out.
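If you want to confirm that the bucketed table actually removes the shuffle, a quick sanity check (a minimal sketch, reusing df_aggreg from the snippet above) is to look at the physical plan; when the bucketing is picked up, there should be no Exchange on col1 before the aggregation:

# Print the physical plan; if the bucketing on col1 is used,
# no Exchange (shuffle) on col1 should appear before the aggregation.
df_aggreg.explain()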

Raghu
  • Thanks, @Raghu. Yes, I am planning on saving the dataset using bucketing and partitioning before performing the groupBy. I will test with collect_list and will respond. I have had issues with collect_list in the past, but it also depends on the implementation. In this scenario, I assume that since I will be using collect_list inside agg, the size should never get too big. – Ani Jun 24 '20 at 04:16
  • super, let me know how it works out – Raghu Jun 24 '20 at 04:36
  • Hi @Raghu, I tested this with 200M records and my app is working as expected without any memory overload. Thanks again for addressing my question from the performance perspective! – Ani Jun 26 '20 at 19:10
  • Happy to hear:-) – Raghu Jun 26 '20 at 22:10