Problem Statement:

I have a dataframe with four columns: service (String), show (String), country_1 (Integer), and country_2 (Integer). My objective is to produce a dataframe with just two columns: service (String) and information (Map[Integer, List[String]]),

where the map can contain multiple key-value entries per streaming service, like this:

{
    "34521": ["The Crown", "Bridgerton", "The Queen's Gambit"],
    "49678": ["The Crown", "Bridgerton", "The Queen's Gambit"]
}
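
For reference, in Spark SQL terms the desired information column corresponds to a map from country code to a list of show titles. A minimal sketch of that type (not part of the original post):

import org.apache.spark.sql.types._

// Sketch of the target type of the "information" column:
// a map from country code (Integer) to the list of shows (List[String]).
val informationType = MapType(IntegerType, ArrayType(StringType))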

One important thing to note is that more countries can be added in the future, for example additional columns in the input dataframe such as "country_3", "country_4", etc. Ideally, the solution should account for these dynamically rather than hardcoding the selected column names, as I did in my attempted solution below.

Input Dataframe:

Schema:

root
|-- service: string (nullable = true)
|-- show: string (nullable = true)
|-- country_1: integer (nullable = true)
|-- country_2: integer (nullable = true)

Dataframe:

service  | show               | country_1 | country_2
---------+--------------------+-----------+----------
Netflix  | The Crown          | 34521     | 49678
Netflix  | Bridgerton         | 34521     | 49678
Netflix  | The Queen's Gambit | 34521     | 49678
Peacock  | The Office         | 34521     | 49678
Disney+  | WandaVision        | 34521     | 49678
Disney+  | Marvel's 616       | 34521     | 49678
Disney+  | The Mandalorian    | 34521     | 49678
Apple TV | Ted Lasso          | 34521     | 49678
Apple TV | The Morning Show   | 34521     | 49678

Output Dataframe:

Schema:

root
|-- service: string (nullable = true)
|-- information: map (nullable = false)
|    |-- key: integer
|    |-- value: array (valueContainsNull = true)
|    |    |-- element: string (containsNull = true)

Dataframe:

service  | information
---------+--------------------------------------------------------------------------------------------------------------
Netflix  | [34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]
Peacock  | [34521 -> [The Office], 49678 -> [The Office]]
Disney+  | [34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]
Apple TV | [34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]

What I have tried already

While I've successfully produced my desired output with the code snippet below, I don't want to rely on basic SQL-style commands, since I don't think they're always optimal for fast computation on large datasets. I also don't want to manually select the country columns by exact name when building the map, because more country columns can be added later.

Is there a better way of doing this that utilizes UDFs, foldLeft, or anything else that helps with optimization, and that also makes the code more concise and less messy?

import org.apache.spark.sql.functions._

val df = spark.read.parquet("filepath/*.parquet")

// Group by service plus the (hardcoded) country columns and collect the shows
val grouped = df.groupBy("service", "country_1", "country_2").agg(collect_list("show").alias("show"))

// Build the map column by hand, listing each country column explicitly
val service_information = grouped
  .withColumn("information", map($"country_1", $"show", $"country_2", $"show"))
  .drop("country_1", "country_2", "show")
corgi123
  • A couple of questions: (1) Does your `country_X` column always have the same country code for all rows (if not, your `groupBy` will result in multiple rows for some `service`)? (2) Could there be `null` in `country_X`, in which case your code would break (since `map` cannot take `null` keys)? – Leo C Feb 07 '21 at 06:27
  • (1) Yes, each country_X column always has the same exact country code: country_1 has all the same values, country_2 does too, and in the future so would country_3, etc., and (2) there are no null values in the country_X columns – corgi123 Feb 07 '21 at 20:38
  • Then, wouldn't every `show` be tied to the exact same list of countries? In that case, wouldn't having a single list of countries suffice? – Leo C Feb 07 '21 at 21:06
  • I'm not entirely sure if I understand your comment, but regarding having an input of a list of countries instead of separate columns for each country, that is out of my control as that is how the data I am receiving is being passed in. I specifically am trying to end up with a Map of multiple records, in which each separate country is a key of one record (key-value pair), and each show correlates to that map. The output schema is what I need it to be in order to fit how the rest of my use case is being handled. – corgi123 Feb 10 '21 at 02:27
  • It's easier to illustrate with sample code. Please see my answer. – Leo C Feb 10 '21 at 04:40
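
Regarding the null-key question in the comments above, here is a minimal sketch, assuming nulls could ever appear in the country columns (dfNoNullCountries is just an illustrative name): the offending rows would have to be dropped or filled before grouping, since map cannot take null keys.

// Drop any row that has a null in one of the country_X columns, so the
// subsequent map() construction never sees a null key.
val countryCols = df.columns.filter(_.startsWith("country_")).toSeq
val dfNoNullCountries = df.na.drop(countryCols)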

1 Answer


As per the country data "specs" described in the comments section (i.e. country code will be identical and non-null in all rows for any given country_X column), your code can be generalized to handle arbitrarily many country columns:

import org.apache.spark.sql.functions._
import spark.implicits._ // for toDF and $"" syntax

val df = Seq(
  ("Netflix",     "The Crown",             34521,    49678),
  ("Netflix",     "Bridgerton",            34521,    49678),
  ("Netflix",     "The Queen's Gambit",    34521,    49678),
  ("Peacock",     "The Office",            34521,    49678),
  ("Disney+",     "WandaVision",           34521,    49678),
  ("Disney+",     "Marvel's 616",          34521,    49678),
  ("Disney+",     "The Mandalorian",       34521,    49678),
  ("Apple TV",    "Ted Lasso",             34521,    49678),
  ("Apple TV",    "The Morning Show",      34521,    49678)
).toDF("service", "show", "country_1", "country_2")

// Dynamically pick up every country column by its name prefix
val countryCols = df.columns.filter(_.startsWith("country_")).toList

// Group by service together with all country columns and collect the shows
val grouped = df.groupBy("service", countryCols: _*).agg(collect_list("show").as("shows"))

// Interleave each country column with the collected show list to form the
// key/value arguments expected by `map`
val service_information = grouped.withColumn(
    "information",
    map(countryCols.flatMap { c => col(c) :: col("shows") :: Nil }: _*)
  ).drop("shows" :: countryCols: _*)

service_information.show(false)
// +--------+--------------------------------------------------------------------------------------------------------------+
// |service |information                                                                                                   |
// +--------+--------------------------------------------------------------------------------------------------------------+
// |Disney+ |[34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]|
// |Peacock |[34521 -> [The Office], 49678 -> [The Office]]                                                                |
// |Netflix |[34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]  |
// |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]                              |
// +--------+--------------------------------------------------------------------------------------------------------------+
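
Since the question mentions foldLeft: the flatMap above simply interleaves each country column with the shows column to form the key/value argument list for map. A minimal foldLeft-based sketch of the same thing, assuming the countryCols and grouped values defined above (service_information2 is just an illustrative name):

import org.apache.spark.sql.Column

// With countryCols = List("country_1", "country_2"), this builds the argument
// list map(col("country_1"), col("shows"), col("country_2"), col("shows")).
val mapArgs = countryCols.foldLeft(Seq.empty[Column]) { (acc, c) =>
  acc :+ col(c) :+ col("shows")
}
val service_information2 = grouped
  .withColumn("information", map(mapArgs: _*))
  .drop("shows" :: countryCols: _*)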

Note that the described country "specs" mandate that all shows be associated with the same list of countries. For instance, if you have three country_X columns and every row of a given country_X is identical without nulls, then every show is tied to those three countries. What if a show were available in only two of the three countries?


In case your data schema could be revised, a more flexible way of maintaining the associated country info would be to keep a single ArrayType column of country codes per show.

val df = Seq(
  ("Netflix",     "The Crown",             Seq(34521, 49678)),
  ("Netflix",     "Bridgerton",            Seq(34521)),
  ("Netflix",     "The Queen's Gambit",    Seq(10001, 49678)),
  ("Peacock",     "The Office",            Seq(34521, 49678)),
  ("Disney+",     "WandaVision",           Seq(10001, 20002, 34521)),
  ("Disney+",     "Marvel's 616",          Seq(49678)),
  ("Disney+",     "The Mandalorian",       Seq(34521, 49678)),
  ("Apple TV",    "Ted Lasso",             Seq(34521, 49678)),
  ("Apple TV",    "The Morning Show",      Seq(20002, 34521))
).toDF("service", "show", "countries")

// Explode the country array so each (service, country, show) is one row,
// then collect the shows per (service, country) pair
val grouped = df.withColumn("country", explode($"countries")).
  groupBy("service", "country").agg(collect_list($"show").as("shows"))

// Collect the countries and their show lists per service, then zip the two
// arrays into a map with map_from_arrays (Spark 2.4+)
val service_information = grouped.groupBy("service").
  agg(collect_list($"country").as("c_list"), collect_list($"shows").as("s_list")).
  select($"service", map_from_arrays($"c_list", $"s_list").as("information"))

service_information.show(false)
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
// |service |information                                                                                                                        |
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
// |Peacock |[34521 -> [The Office], 49678 -> [The Office]]                                                                                     |
// |Disney+ |[20002 -> [WandaVision], 49678 -> [Marvel's 616, The Mandalorian], 34521 -> [WandaVision, The Mandalorian], 10001 -> [WandaVision]]|
// |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso], 20002 -> [The Morning Show]]                                        |
// |Netflix |[49678 -> [The Crown, The Queen's Gambit], 10001 -> [The Queen's Gambit], 34521 -> [The Crown, Bridgerton]]                        |
// +--------+-----------------------------------------------------------------------------------------------------------------------------------+
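
A hedged usage sketch, not part of the original answer: once information is a map column, downstream code can pull out the show list for a single country code with element_at, available since Spark 2.4 (shows_34521 is just an illustrative alias).

// element_at returns the map value for the given key, or null if absent.
service_information
  .select($"service", element_at($"information", 34521).as("shows_34521"))
  .show(false)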
Leo C
  • Thank you so much! Can you help explain the purpose of using Nil, or maybe even in a broader sense, what your code is doing to create val service_information in the first implementation you provided? – corgi123 Feb 10 '21 at 05:09
  • @corgi123, `col(c) :: col("shows") :: Nil` is just one way of expressing `List(col(c), col("shows"))`, allowing `flatMap` to pair up every country column with the collected `shows` column as parameters for method [map](https://spark.apache.org/docs/2.4.0/api/scala/index.html#org.apache.spark.sql.functions$@map(cols:org.apache.spark.sql.Column*):org.apache.spark.sql.Column). – Leo C Feb 10 '21 at 05:32
  • Makes sense I think, thanks! With regards to the second implementation, that's a good way to approach it if, in the future, the data I receive is more varied between countries and shows and not as uniform as it is now. – corgi123 Feb 10 '21 at 06:01