Problem Statement:
I have a dataframe with four columns: service (String), show (String), country_1 (Integer), and country_2 (Integer). My objective is to produce a dataframe with just two columns: service (String) and information (Map[Integer, List[String]]), where the map can contain multiple key-value pairs per streaming service, like this:
{
"34521": ["The Crown", "Bridgerton", "The Queen's Gambit"],
"49678": ["The Crown", "Bridgerton", "The Queen's Gambit"]
}
One important thing to note is that more countries may be added in the future, i.e. the input dataframe could gain additional columns such as "country_3", "country_4", and so on. Ideally the solution would account for this rather than hardcoding the selected columns, as my attempted solution below does.
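For illustration, the kind of dynamic column selection I have in mind would be something like the snippet below, where df is the input dataframe; the "country_" prefix check and the countryCols name are just my assumption about how the columns might be detected:

// Derive the country columns from the schema instead of hardcoding them,
// assuming they always follow the "country_N" naming pattern
val countryCols = df.columns.filter(_.startsWith("country_"))
// today this yields Array("country_1", "country_2"), and it would pick up
// "country_3", "country_4", etc. automatically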
Input Dataframe:
Schema:
root
|-- service: string (nullable = true)
|-- show: string (nullable = true)
|-- country_1: integer (nullable = true)
|-- country_2: integer (nullable = true)
Dataframe:
service  | show               | country_1 | country_2
Netflix  | The Crown          | 34521     | 49678
Netflix  | Bridgerton         | 34521     | 49678
Netflix  | The Queen's Gambit | 34521     | 49678
Peacock  | The Office         | 34521     | 49678
Disney+  | WandaVision        | 34521     | 49678
Disney+  | Marvel's 616       | 34521     | 49678
Disney+  | The Mandalorian    | 34521     | 49678
Apple TV | Ted Lasso          | 34521     | 49678
Apple TV | The Morning Show   | 34521     | 49678
Output Dataframe:
Schema:
root
|-- service: string (nullable = true)
|-- information: map (nullable = false)
| |-- key: integer
| |-- value: array (valueContainsNull = true)
| | |-- element: string (containsNull = true)
Dataframe:
service  | information
Netflix  | [34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]
Peacock  | [34521 -> [The Office], 49678 -> [The Office]]
Disney+  | [34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]
Apple TV | [34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]
What I have tried already:
While I've successfully produced my desired output with the code snippet below, I'd like to avoid two things: relying on basic SQL-style commands, which I don't think are always optimal for fast computation on large datasets, and manually selecting the country columns by exact name when building the map, since more country columns can be added later.
Is there a better way to do this, e.g. using UDFs, foldLeft, or anything else that helps with optimization and keeps the code concise and tidy?
import org.apache.spark.sql.functions._
import spark.implicits._

val df = spark.read.parquet("filepath/*.parquet")
val temp = df.groupBy("service", "country_1", "country_2").agg(collect_list("show").alias("show"))
// Build the map by hand: each hardcoded country column keys the same collected show list
val service_information = temp
  .withColumn("information", map($"country_1", $"show", $"country_2", $"show"))
  .drop("country_1", "country_2", "show")
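For reference, here's a rough sketch of the generalized direction I've been considering (building on the countryCols idea above; the grouped and mapEntries names are mine, and I'm assuming the "country_N" naming convention holds), though I'm not sure it's the optimal approach:

import org.apache.spark.sql.functions._

// Sketch only: detect the country columns dynamically rather than hardcoding them
val countryCols = df.columns.filter(_.startsWith("country_"))

// Group by service plus every country column, collecting the shows per group
val grouped = df
  .groupBy(("service" +: countryCols.toSeq).map(col): _*)
  .agg(collect_list("show").alias("shows"))

// Interleave each country column with the collected show list: key1, value1, key2, value2, ...
val mapEntries = countryCols.toSeq.flatMap(c => Seq(col(c), col("shows")))

val service_information = grouped
  .withColumn("information", map(mapEntries: _*))
  .select("service", "information")

A foldLeft over countryCols that accumulates the key-value pairs would presumably work just as well; the varargs map(...) call simply seemed the most direct way to express it.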