from pyspark.sql.functions import count, sum

aggregrated_table = df_input.groupBy('city', 'income_bracket') \
    .agg(
        count('suburb').alias('suburb'),
        sum('population').alias('population'),
        sum('gross_income').alias('gross_income'),
        sum('no_households').alias('no_households'))

I would like to group by city and income bracket, but within each city certain suburbs fall into different income brackets. How do I group by the most frequently occurring income bracket per city?

For example:

city1 suburb1 income_bracket_10 
city1 suburb1 income_bracket_10 
city1 suburb2 income_bracket_10 
city1 suburb3 income_bracket_11 
city1 suburb4 income_bracket_10 

These rows would all be grouped under income_bracket_10, the most frequent bracket for city1.
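
(For reference, a minimal frame reproducing this example could be built as sketched below; the schema is inferred from the aggregation code, and the numeric values are made-up placeholders.)

# Hypothetical reproduction of the example rows above; only
# city / suburb / income_bracket matter for the grouping question.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df_input = spark.createDataFrame(
    [
        ("city1", "suburb1", "income_bracket_10", 100, 5000.0, 40),
        ("city1", "suburb1", "income_bracket_10", 120, 6000.0, 45),
        ("city1", "suburb2", "income_bracket_10", 90, 4500.0, 30),
        ("city1", "suburb3", "income_bracket_11", 80, 4000.0, 25),
        ("city1", "suburb4", "income_bracket_10", 110, 5500.0, 35),
    ],
    ["city", "suburb", "income_bracket", "population", "gross_income", "no_households"],
)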

Jess

4 Answers


Using a window function before aggregating might do the trick:

from pyspark.sql import Window
import pyspark.sql.functions as psf

# count occurrences of each (city, income_bracket) pair, then keep only the
# rows belonging to each city's most frequent bracket before aggregating
w = Window.partitionBy('city', 'income_bracket')
aggregrated_table = df_input.withColumn(
    "count",
    psf.count("*").over(w)
).withColumn(
    "rn",
    psf.dense_rank().over(Window.partitionBy('city').orderBy(psf.desc("count")))
).filter("rn = 1").groupBy('city', 'income_bracket').agg(
    psf.count('suburb').alias('suburb'),
    psf.sum('population').alias('population'),
    psf.sum('gross_income').alias('gross_income'),
    psf.sum('no_households').alias('no_households'))

You can also apply the window function after aggregating, since the aggregation in the question already keeps a count of (city, income_bracket) occurrences (the count('suburb') column); see the sketch below.
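
A sketch of that variant, reusing the imports above; aggregated_after and w_after are my names, and it assumes suburb is never null so that count('suburb') equals the number of rows per group:

# Aggregate per (city, income_bracket) first, then keep, for each city,
# the row whose bracket occurs most often (highest 'suburb' count).
w_after = Window.partitionBy('city').orderBy(psf.desc('suburb'))

aggregated_after = df_input.groupBy('city', 'income_bracket').agg(
    psf.count('suburb').alias('suburb'),
    psf.sum('population').alias('population'),
    psf.sum('gross_income').alias('gross_income'),
    psf.sum('no_households').alias('no_households')
).withColumn('rn', psf.row_number().over(w_after)) \
 .filter('rn = 1').drop('rn')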

MaFF
  • Perfect - thanks! I did have some issues with null values that take precedence over actual values, but used your solution in combination with https://stackoverflow.com/questions/35142216/first-value-windowing-function-in-pyspark and it works! – Jess Aug 21 '17 at 13:27
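
A possible shape of the combination that comment describes, as a sketch only (w_pair, w_city, and top_bracket are names introduced here; the null handling relies on first(..., ignorenulls=True) from the linked question):

# Count bracket occurrences per city, then label every row of a city with
# the first non-null bracket when ordered by that count, descending.
from pyspark.sql import Window
import pyspark.sql.functions as psf

w_pair = Window.partitionBy('city', 'income_bracket')
w_city = Window.partitionBy('city').orderBy(psf.desc('count'))

df_labeled = (
    df_input
    .withColumn('count', psf.count('*').over(w_pair))
    .withColumn('top_bracket',
                psf.first('income_bracket', ignorenulls=True).over(w_city))
)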

You don't necessarily need Window functions:

import pyspark.sql.functions as F

aggregrated_table = (
    df_input.groupby("city", "suburb", "income_bracket")
    .count()
    .withColumn("count_income", F.array("count", "income_bracket"))
    .groupby("city", "suburb")
    .agg(F.max("count_income").getItem(1).alias("most_common_income_bracket"))
)

I think this does what you require. I don't know whether it performs better than the window-based solution, though.

mfcabrera
  • The solution by mfcabrera is better for very large datasets, where you won't force the entire dataset into a single node. – thentangler May 21 '20 at 02:47

For PySpark version >= 3.4 you can use the mode function directly to get the most frequent element per group:

from pyspark.sql import functions as f

df = spark.createDataFrame([
    ("Java", 2012, 20000), ("dotNET", 2012, 5000),
    ("Java", 2012, 20000), ("dotNET", 2012, 5000),
    ("dotNET", 2013, 48000), ("Java", 2013, 30000)],
    schema=("course", "year", "earnings"))

df.groupby("course").agg(f.mode("year")).show()
+------+----------+
|course|mode(year)|
+------+----------+
|  Java|      2012|
|dotNET|      2012|
+------+----------+

https://github.com/apache/spark/blob/7f1b6fe02bdb2c68d5fb3129684ca0ed2ae5b534/python/pyspark/sql/functions.py#L379
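
Applied to the question's frame, that would presumably look like this (df_input and the column names as in the question):

# Most frequent income_bracket per city via the built-in mode aggregate
# (requires PySpark >= 3.4).
df_input.groupby('city').agg(f.mode('income_bracket').alias('income_bracket')).show()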

Jan_ewazz

The solution by mfcabrera gave wrong results when F.max was used on the F.array column, because the values in the ArrayType are treated as strings and the integer max didn't work as expected.

The solution below worked for me:

from pyspark.sql import Window
import pyspark.sql.functions as f

w = Window.partitionBy("city", "suburb").orderBy(f.desc("count"))

aggregrated_table = (
    input_df.groupby("city", "suburb", "income_bracket")
    .count()
    .withColumn("max_income", f.row_number().over(w))
    .filter(f.col("max_income") == 1)
    .drop("max_income")
)
aggregrated_table.display()