
I have the following dataframe:

>>> my_df.show()
+------------+---------+-------+-----+--------+
|     user_id|  address|   type|count| country|
+------------+---------+-------+-----+--------+
|      ABC123|  yyy,USA| animal|    2|     USA|
|      ABC123|  xxx,USA| animal|    3|     USA|
|      qwerty|  55A,AUS|  human|    3|     AUS|
|      ABC123|  zzz,RSA| animal|    4|     RSA|
+------------+---------+-------+-----+--------+

How do I roll up this dataframe to get the following result?

>>> new_df.show(3)
+------------+---------+-------+-----+--------+
|     user_id|  address|   type|count| country|
+------------+---------+-------+-----+--------+
|      qwerty|  55A,AUS|  human|    3|     AUS|
|      ABC123|  xxx,USA| animal|    5|     USA|
+------------+---------+-------+-----+--------+

For a given user_id:

  1. Get the country with the highest sum of counts
  2. For the country got in step 1, get the address with the highest count

I'm guessing I'll have to split my_df into two different dataframes and get the country and the address separately, but I don't know the exact syntax for that.
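
Roughly, for step 1 I'm picturing something like the sketch below (untested, and only using the columns shown in my_df above), but I'm not sure how to get from there to the address in step 2.

>>> from pyspark.sql import functions as F

>>> # step 1 (rough idea): total count per user_id and country
>>> country_totals = (my_df
...     .groupBy("user_id", "country")
...     .agg(F.sum("count").alias("sum_count")))
>>> country_totals.show()

Your help is appreciated. Thanks.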

kev

1 Answer


You can do something like this:

>>> import pandas as pd
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.window import *
>>> from pyspark.sql import SparkSession

>>> spark = SparkSession.builder.appName('abc').getOrCreate()

>>> data = {
...     "user_id": ["ABC123", "ABC123", "qwerty", "ABC123"],
...     "address": ["yyy,USA", "xxx,USA", "55A,AUS", "zzz,RSA"],
...     "type": ["animal", "animal", "human", "animal"],
...     "count": [2, 3, 3, 4],
...     "country": ["USA", "USA", "AUS", "RSA"],
... }

>>> df = pd.DataFrame(data=data)

>>> df_pyspark = spark.createDataFrame(df)

>>> # rows of each (user_id, country), highest count first
>>> w = Window.partitionBy("user_id", "country").orderBy(col("count").desc())

>>> # rows of each user_id, highest windowed country sum first
>>> w2 = Window.partitionBy("user_id").orderBy(col("sum_country").desc())

>>> # attach the windowed sum of count, then keep for each user only the rows
>>> # of the top country and, within it, the address with the highest count
>>> (df_pyspark
...     .select("user_id", "address", "type", "count", "country",
...             sum("count").over(w).alias("sum_country"))
...     .select("user_id",
...             first("country").over(w2).alias("top_country"),
...             first("address").over(w).alias("top_address"),
...             "country")
...     .where(col("top_country") == col("country"))
...     .distinct()
...     .show())
+-------+-----------+-----------+-------+
|user_id|top_country|top_address|country|
+-------+-----------+-----------+-------+
| qwerty|        AUS|    55A,AUS|    AUS|
| ABC123|        USA|    xxx,USA|    USA|
+-------+-----------+-----------+-------+

You may also add type, count, etc., depending on which logic you want for them - either handle them the same way as top_address (i.e. with the first function over a window), or use groupBy and agg.
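
To also carry through the total count of the winning country, one option (just a sketch reusing the same w and w2 windows, and assuming the counts are non-negative - I have not run this against your real data) is to take it with first over w2 as well, so distinct() can still collapse the duplicate rows:

>>> (df_pyspark
...     .select("user_id", "address", "type", "count", "country",
...             sum("count").over(w).alias("sum_country"))
...     .select("user_id",
...             first("country").over(w2).alias("top_country"),
...             first("address").over(w).alias("top_address"),
...             # total count of the user's top country - the same value on
...             # every row of the user, so distinct() still deduplicates
...             first("sum_country").over(w2).alias("top_count"),
...             "country")
...     .where(col("top_country") == col("country"))
...     .distinct()
...     .show())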

Grzegorz Skibinski
  • this works well. Although, the `df_pyspark.count()` is different when I add the `type` and `count` columns than when I don't. I did `df_pyspark.select("user_id", "address", "type", "count", "country", sum("count").over(w).alias("sum_country")).select("user_id", first("country").over(w2).alias("top_country"), first("address").over(w).alias("top_address"), "country", "type", col("sum_country").alias("count")).where(col("top_country")==col("country")).distinct().show()`. What do I do? – kev Sep 24 '19 at 15:37
  • Also, `type` will always remain the same for a given `user_id`. – kev Sep 24 '19 at 15:58
  • Hm, this should work fine then. So the above, if you replace ```.show()``` with ```.count()``` should return 2, does it show different for you? – Grzegorz Skibinski Sep 24 '19 at 19:45
  • I didn't get you. Btw, I just checked the data - there seem to be duplicate rows, because there are **multiple different counts** for the same combination of **user_id, address, type, country**. How did this happen? Should I just do a `groupBy("user_id", "address", "type", "country").sum()` on the final `df_pyspark`? – kev Sep 25 '19 at 14:17
  • Yes, because ```distinct()``` only removes duplicates when the entire row (all columns) matches, so you can either incorporate ```count``` into the windowing function analogously to e.g. ```address```, so: ```first("count").over(w).alias("top_count")```, to make it the same per window. Or you can do ```groupBy``` ```agg``` and take e.g. max of ```count``` as per: https://spark.apache.org/docs/2.1.0/api/python/pyspark.sql.html#pyspark.sql.DataFrame.groupBy – Grzegorz Skibinski Sep 25 '19 at 14:21
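
For reference, a minimal sketch of the groupBy/agg option from the last comment, shown here on the example df_pyspark (the same pattern applies to the final dataframe that contains the duplicate rows; max is the pyspark.sql.functions one imported above):

>>> # keep one row per (user_id, address, type, country), with the largest count
>>> deduped = (df_pyspark
...     .groupBy("user_id", "address", "type", "country")
...     .agg(max("count").alias("count")))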