
Suppose I have the following DataFrame:

Model Color
Car Red
Car Red
Car Blue
Truck Red
Truck Blue
Truck Yellow
SUV Blue
SUV Blue
Car Blue
Car Yellow

I want to add color columns that keep a count of each color for each model, to give the following DataFrame:

Model Color Red Blue Yellow
Car Red 2 2 1
Car Red 2 2 1
Car Blue 2 2 1
Truck Red 1 1 1
Truck Blue 1 1 1
Truck Yellow 1 1 1
SUV Blue 0 2 0
SUV Blue 0 2 0
Car Blue 2 2 1
Car Yellow 2 2 1

This dataset has billions of records, so I'm trying to stay away from UDFs and prefer to use built-in methods if possible.

I normally use a window function with .size() and .collect_set() to count this type of data, but adding multiple new columns based on different categories of the same column is causing me issues: I'm not sure whether I need to isolate the individual categories with additional window partitions, or use a single window partition combined with a .where() or .isin() filter. Any feedback or recommendations are appreciated. Thank you.
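
For illustration, here is a minimal sketch of the approach described above, assuming df is the example DataFrame: partitioning the window by both Model and Color only counts the row's own color, which is why this alone cannot produce one column per color.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Partitioning by both Model and Color counts only the row's own color,
# so this by itself cannot yield separate Red/Blue/Yellow columns.
w = Window.partitionBy("Model", "Color")
df.withColumn("color_count", F.count("*").over(w)).show()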

2 Answers


Let's do that with window functions and built-in PySpark DataFrame functions. Note that window functions can be computationally expensive, especially on large datasets, so it may be worth looking for a cheaper approach for your use case. Don't forget to replace data and df with your actual data and DataFrame.

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, sum as _sum

spark = SparkSession.builder.getOrCreate()

data = [("Car","Red"), ("Car","Red"), ("Car","Blue"), ("Truck","Red"), 
        ("Truck","Blue"), ("Truck","Yellow"), ("SUV","Blue"), 
        ("SUV","Blue"), ("Car","Blue"), ("Car","Yellow")]

df = spark.createDataFrame(data, ["Model", "Color"])

# One window per model; each conditional sum counts the rows of a given
# color within that model's partition.
window_model = Window.partitionBy('Model')

df = df.withColumn('Red', _sum((col('Color') == 'Red').cast('int')).over(window_model))
df = df.withColumn('Blue', _sum((col('Color') == 'Blue').cast('int')).over(window_model))
df = df.withColumn('Yellow', _sum((col('Color') == 'Yellow').cast('int')).over(window_model))

df.show()
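
If the set of distinct colors is known up front, the repeated withColumn calls can be generated in a loop instead (a sketch; the colors list here is taken from the example data):

colors = ["Red", "Blue", "Yellow"]  # taken from the example data
for c in colors:
    df = df.withColumn(c, _sum((col("Color") == c).cast("int")).over(window_model))

df.show()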
Saxtheowl

If you don't care about preserving the original order, it can be done as a one-liner:

data = [("Car","Red"), ("Car","Red"), ("Car","Blue"), ("Truck","Red"), 
        ("Truck","Blue"), ("Truck","Yellow"), ("SUV","Blue"), 
        ("SUV","Blue"), ("Car","Blue"), ("Car","Yellow")]

df = spark.createDataFrame(data, ["Model", "Color"])

df.join(df.groupBy("Model").pivot("Color").count().fillna(0), on='Model').show()

# +-----+------+----+---+------+
# |Model| Color|Blue|Red|Yellow|
# +-----+------+----+---+------+
# |  SUV|  Blue|   2|  0|     0|
# |  SUV|  Blue|   2|  0|     0|
# |  Car|   Red|   2|  2|     1|
# |  Car|   Red|   2|  2|     1|
# |  Car|  Blue|   2|  2|     1|
# |  Car|  Blue|   2|  2|     1|
# |  Car|Yellow|   2|  2|     1|
# |Truck|   Red|   1|  1|     1|
# |Truck|  Blue|   1|  1|     1|
# |Truck|Yellow|   1|  1|     1|
# +-----+------+----+---+------+
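
Because the grouped side has only one row per model, it is usually small enough to broadcast, which avoids shuffling the large table for the join (a sketch, assuming the number of distinct models is small):

import pyspark.sql.functions as F

model_counts = df.groupBy("Model").pivot("Color").count().fillna(0)
df.join(F.broadcast(model_counts), on="Model").show()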

If you care about the order, it can still be done as a one-liner:

import pyspark.sql.functions as F

df.withColumn("_id", F.monotonically_increasing_id())\
    .join(df.groupBy("Model").pivot("Color").count().fillna(0), on='Model')\
    .orderBy("_id").drop("_id").show()

# +-----+------+----+---+------+
# |Model| Color|Blue|Red|Yellow|
# +-----+------+----+---+------+
# |  Car|   Red|   2|  2|     1|
# |  Car|   Red|   2|  2|     1|
# |  Car|  Blue|   2|  2|     1|
# |Truck|   Red|   1|  1|     1|
# |Truck|  Blue|   1|  1|     1|
# |Truck|Yellow|   1|  1|     1|
# |  SUV|  Blue|   2|  0|     0|
# |  SUV|  Blue|   2|  0|     0|
# |  Car|  Blue|   2|  2|     1|
# |  Car|Yellow|   2|  2|     1|
# +-----+------+----+---+------+
Hristo Iliev