
I have a DataFrame like the one below. I would like to group by device and order by start_time within each group. Then, for each row in the group, I want the most frequently occurring station in a window made of that row and the 2 rows before it (3 rows in total).

from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.getOrCreate()

columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]

test_df = spark.createDataFrame(data).toDF(*columns)
# Current row plus the 2 preceding rows, per device, ordered by start_time
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)
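To pin down the intended semantics before involving Spark, here is a minimal plain-Python sketch that mimics `rowsBetween(-2, 0)`: rows are grouped by device and sorted by `start_time`, each row looks at itself and up to two preceding rows, `None` stations are ignored, and the most frequent station is returned. The helper names `rolling_mode` and `pick_mode` are hypothetical. One assumption to flag: in row 2 the window holds one `station_1` and one `station_2`, so the desired output implies some tie-break; this sketch assumes ties go to the most recent station, which is what reproduces the desired output below.

```python
from collections import Counter
from itertools import groupby


def pick_mode(window):
    """Most frequent non-None value in `window`; ties go to the most recent one (assumption)."""
    values = [v for v in window if v is not None]
    if not values:
        return None
    counts = Counter(values)
    best = max(counts.values())
    # Scan from newest to oldest so the latest value wins a tie.
    for v in reversed(values):
        if counts[v] == best:
            return v


def rolling_mode(rows, size=3):
    """rows: (device, start_time, station) tuples; returns them with a 4th field appended."""
    out = []
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))  # like partitionBy + orderBy
    for _, group in groupby(ordered, key=lambda r: r[0]):
        group = list(group)
        for i, row in enumerate(group):
            frame = group[max(0, i - size + 1): i + 1]  # like rowsBetween(-2, 0)
            out.append(row + (pick_mode([r[2] for r in frame]),))
    return out


data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"),
        ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]
for row in rolling_mode(data):
    print(row)
```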

Desired output:

+------+----------+---------+--------------------+
|device|start_time|  station|rolling_mode_station|
+------+----------+---------+--------------------+
|Python|         1|station_1|           station_1|
|Python|         2|station_2|           station_2|
|Python|         3|station_1|           station_1|
|Python|         4|station_2|           station_2|
|Python|         5|station_2|           station_2|
|Python|         6|     null|           station_2|
+------+----------+---------+--------------------+

PySpark does not have a mode() function. I know how to get the most frequent value in a static groupBy, as shown here, but I don't know how to adapt it to a rolling window.

blackbishop

2 Answers


You can use the collect_list function over the defined window to gather the stations from the last 3 rows into an array, then compute the most frequent element of each array. Note that collect_list skips nulls, which is why row 6 still gets station_2.

To get the most frequent element of the array, you can either explode it, then group by and count as in the linked post you already saw, or use a UDF like this:

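import pyspark.sql.functions as F

# Most frequent element of an array. Returns None for an empty array
# (e.g. when every station in the window is null). Ties are broken
# arbitrarily, by whichever element the set yields first.
mode_udf = F.udf(lambda x: max(set(x), key=x.count) if x else None, "string")

test_df.withColumn(
    "rolling_mode_station",
    F.collect_list("station").over(rolling_w)  # collect_list skips nulls
).withColumn(
    "rolling_mode_station",
    mode_udf(F.col("rolling_mode_station"))
).show()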

#+------+----------+---------+--------------------+
#|device|start_time|  station|rolling_mode_station|
#+------+----------+---------+--------------------+
#|Python|         1|station_1|           station_1|
#|Python|         2|station_2|           station_1|
#|Python|         3|station_1|           station_1|
#|Python|         4|station_2|           station_2|
#|Python|         5|station_2|           station_2|
#|Python|         6|     null|           station_2|
#+------+----------+---------+--------------------+
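Row 2 differs from the desired output because its window, [station_1, station_2], is a tie: `max(set(x), key=x.count)` returns whichever tied element the set iterates first, and since Python randomizes string hashing per process, that choice can vary between runs or executors. A minimal sketch of a deterministic replacement, assuming ties should go to the latest element in the array and that collect_list keeps the frame's start_time order; the name `mode_latest` is hypothetical, and it could be wrapped as `F.udf(mode_latest, "string")` in place of the lambda:

```python
def mode_latest(x):
    """Most frequent element of x; ties go to the element that appears last.

    Returns None for an empty array (a window containing only nulls).
    """
    if not x:
        return None
    # Compare positions by (count, index): the highest count wins,
    # and among equal counts the largest (latest) index wins.
    best = max(range(len(x)), key=lambda i: (x.count(x[i]), i))
    return x[best]


# Arrays that collect_list is expected to build for the sample data (nulls dropped).
windows = [["station_1"],
           ["station_1", "station_2"],
           ["station_1", "station_2", "station_1"],
           ["station_2", "station_1", "station_2"],
           ["station_1", "station_2", "station_2"],
           ["station_2", "station_2"]]
print([mode_latest(w) for w in windows])
```

The quadratic `x.count` calls are harmless here because each window holds at most 3 elements.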
blackbishop

I had a similar requirement, and this is how I achieved it.

Step 1: Create a UDF that returns the most common element of an array:

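import pyspark.sql.functions as F
from collections import Counter

@F.udf(returnType="string")
def mode(x):
    # Most common element, or None for an empty array
    # (collect_list drops nulls, so an all-null window yields []).
    if not x:
        return None
    return Counter(x).most_common(1)[0][0]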
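Unlike a `set`-based approach, `Counter.most_common` is deterministic on ties: the Python documentation states that elements with equal counts are ordered in the order first encountered. So for row 2, assuming collect_list keeps the frame's start_time order, this UDF returns `station_1` rather than the `station_2` shown in the desired output. A plain-Python check of the same logic with the Spark decorator removed (the name `mode_py` is hypothetical):

```python
from collections import Counter


def mode_py(x):
    # Same body as the UDF above, without the Spark wrapper.
    if not x:
        return None
    return Counter(x).most_common(1)[0][0]


# Arrays collect_list is expected to produce over rowsBetween(-2, 0) for the sample data.
windows = [["station_1"],
           ["station_1", "station_2"],
           ["station_1", "station_2", "station_1"],
           ["station_2", "station_1", "station_2"],
           ["station_1", "station_2", "station_2"],
           ["station_2", "station_2"]]
print([mode_py(w) for w in windows])
```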

Step 2: Collect the stations in each window into an array:

test_df_tmp = test_df.withColumn(
    "rolling_mode_station",
    F.collect_list("station").over(rolling_w)
)
test_df_tmp.show(truncate=False)

Step 3: Apply the UDF created in Step 1:

test_df_tmp.select('device', 'start_time', 'station', mode('rolling_mode_station').alias('rolling_mode_station')).show()