I am trying to generate window aggregates for my data, but it is taking too long for lags > 20. I am running this on Databricks.
My data has columns: userid, date, orders, total_spend
+------+----------+------+-----------+
|userid|date      |orders|total_spend|
+------+----------+------+-----------+
|1     |2022-05-01|2     |1000       |
|1     |2022-05-02|3     |2000       |
|2     |2022-05-01|1     |2000       |
|3     |2022-05-01|2     |2000       |
|3     |2022-05-02|4     |3000       |
|4     |2022-05-01|1     |400        |
|5     |2022-05-01|2     |2000       |
|5     |2022-05-02|4     |1500       |
|5     |2022-05-02|2     |6000       |
+------+----------+------+-----------+
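For reference, the sample above can be recreated with something like this (a minimal sketch, assuming the Databricks-provided spark session; treating date as a DateType column is my assumption):

from pyspark.sql import functions as F

# Hypothetical sample DataFrame matching the table above.
sample_df = spark.createDataFrame(
    [
        (1, "2022-05-01", 2, 1000),
        (1, "2022-05-02", 3, 2000),
        (2, "2022-05-01", 1, 2000),
        (3, "2022-05-01", 2, 2000),
        (3, "2022-05-02", 4, 3000),
        (4, "2022-05-01", 1, 400),
        (5, "2022-05-01", 2, 2000),
        (5, "2022-05-02", 4, 1500),
        (5, "2022-05-02", 2, 6000),
    ],
    ["userid", "date", "orders", "total_spend"],
).withColumn("date", F.to_date("date"))  # assumed: date stored as DateType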
Here is the code I am running:

from pyspark.sql import functions as F

# Sliding window covering the last `lag` days, advancing one day at a time.
def getWindow(lag):
    return F.window(
        F.col("date"),
        windowDuration=f"{lag} days",
        slideDuration="1 days",
    ).alias("window")

# Per-user mean and sum of `column` over the sliding window, keyed back to a date.
def getAggregated(df, window, column, lag):
    return (
        df
        .groupBy(F.col("userid"), window)
        .agg(
            F.avg(F.col(column)).alias(f"mean_{column}_last{lag}days"),
            F.sum(F.col(column)).alias(f"sum_{column}_last{lag}days")
        )
        .withColumn("date", F.date_sub(F.col("window.end").cast("date"), 0))
        .drop("window")
    )
LAGS = [1, 3, 10, 20, 40, 80, 180]
COLUMNS_TO_BE_AGGREGATED = [
    "orders",
    "total_spend"
]
df = spark.read.parquet("df_location")
df = df.orderBy("userid", "date")
df.persist()
for col in COLUMNS_TO_BE_AGGREGATED:
    for lag in LAGS:
        window = getWindow(lag)
        agg_df = getAggregated(df, window, col, lag)
        df = df.join(agg_df, ["userid", "date"], how="left")
Is there something I am doing incorrectly? Any suggestions on how I can optimize this?