
I have the following DF:

|-----------------------|
|Date       | Val | Cond|
|-----------------------|
|2022-01-08 | 2   | 0   |
|2022-01-09 | 4   | 1   |
|2022-01-10 | 6   | 1   |
|2022-01-11 | 8   | 0   |
|2022-01-12 | 2   | 1   |
|2022-01-13 | 5   | 1   |
|2022-01-14 | 7   | 0   |
|2022-01-15 | 9   | 0   | 
|-----------------------|

For every date, I need the sum of the Val values from the two most recent previous dates where Cond = 1. My expected output is:

|-----------------|
|Date       | Sum |
|-----------------|
|2022-01-08 | 0   |  No sum because there aren't two dates with Cond = 1 before this date
|2022-01-09 | 0   |  No sum because there aren't two dates with Cond = 1 before this date
|2022-01-10 | 0   |  No sum because there aren't two dates with Cond = 1 before this date
|2022-01-11 | 10  | (4+6)
|2022-01-12 | 10  | (4+6)
|2022-01-13 | 8   | (2+6)
|2022-01-14 | 7   | (5+2)
|2022-01-15 | 7   | (5+2)
|-----------------|

I've tried to get the output DF using this code:

df = df.where("Cond = 1").withColumn(
    "ListView",
    f.collect_list("Val").over(windowSpec.rowsBetween(-2, -1))
)
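
(Here windowSpec is assumed to be an ordered-by-date window, something like:)

# Assumed definition of the windowSpec used in the snippet above (not shown explicitly)
from pyspark.sql import Window

windowSpec = Window.orderBy("Date")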

But when I use .where("Cond = 1"), I exclude the dates where Cond is equal to zero.

I found the following answer, but it didn't help me:

Window.rowsBetween - only consider rows fulfilling a specific condition (e.g. not being null)

How can I achieve my expected output using window functions?

The MVCE:

from pyspark.sql.types import StructType, StructField, DateType, IntegerType

data_1 = [
    ("2022-01-08",2,0),
    ("2022-01-09",4,1),
    ("2022-01-10",6,1),
    ("2022-01-11",8,0),
    ("2022-01-12",2,1),
    ("2022-01-13",5,1),
    ("2022-01-14",7,0),
    ("2022-01-15",9,0) 
]

schema_1 = StructType([
    StructField("Date", DateType(),True),
    StructField("Val", IntegerType(),True),
    StructField("Cond", IntegerType(),True)
  ])

df_1 = spark.createDataFrame(data=data_1,schema=schema_1)
1 Answer

The following should do the trick (but I'm sure it can be further optimized).

Setup:

from pyspark.sql import Window
from pyspark.sql.functions import col, max, sum, to_date, when
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

data_1 = [
    ("2022-01-08",2,0),
    ("2022-01-09",4,1),
    ("2022-01-10",6,1),
    ("2022-01-11",8,0),
    ("2022-01-12",2,1),
    ("2022-01-13",5,1),
    ("2022-01-14",7,0),
    ("2022-01-15",9,0),
    ("2022-01-16",9,0),
    ("2022-01-17",9,0)
]

schema_1 = StructType([
    StructField("Date", StringType(),True),
    StructField("Val", IntegerType(),True),
    StructField("Cond", IntegerType(),True)
  ])

df_1 = spark.createDataFrame(data=data_1,schema=schema_1)
df_1 = df_1.withColumn('Date', to_date("Date", "yyyy-MM-dd"))

+----------+---+----+
|      Date|Val|Cond|
+----------+---+----+
|2022-01-08|  2|   0|
|2022-01-09|  4|   1|
|2022-01-10|  6|   1|
|2022-01-11|  8|   0|
|2022-01-12|  2|   1|
|2022-01-13|  5|   1|
|2022-01-14|  7|   0|
|2022-01-15|  9|   0|
|2022-01-16|  9|   0|
|2022-01-17|  9|   0|
+----------+---+----+

Create a new DF with only the Cond==1 rows, to obtain the sum of two consecutive rows that meet that condition:

windowSpec = Window.partitionBy("Cond").orderBy("Date")

# rowsBetween(-1, 0): sum of the previous and the current Cond == 1 row
df_2 = df_1.where(df_1.Cond == 1).withColumn(
    "Sum",
    sum("Val").over(windowSpec.rowsBetween(-1, 0))
).withColumn('date_1', col('date')).drop('date')

+---+----+---+----------+
|Val|Cond|Sum|    date_1|
+---+----+---+----------+
|  4|   1|  4|2022-01-09|
|  6|   1| 10|2022-01-10|
|  2|   1|  8|2022-01-12|
|  5|   1|  7|2022-01-13|
+---+----+---+----------+

Do a left join to get the sum into the original data frame, and set the sum to zero for the rows with Cond==0:

df_3 = df_1.join(
    df_2.select('sum', col('date_1')),
    df_1.Date == df_2.date_1,
    "left"
).drop('date_1').fillna(0)

+----------+---+----+---+
|      Date|Val|Cond|sum|
+----------+---+----+---+
|2022-01-08|  2|   0|  0|
|2022-01-09|  4|   1|  4|
|2022-01-10|  6|   1| 10|
|2022-01-11|  8|   0|  0|
|2022-01-12|  2|   1|  8|
|2022-01-13|  5|   1|  7|
|2022-01-14|  7|   0|  0|
|2022-01-15|  9|   0|  0|
|2022-01-16|  9|   0|  0|
|2022-01-17|  9|   0|  0|
+----------+---+----+---+

Do a cumulative sum on the condition column:

df_3 = df_3.withColumn('cond_sum', sum('cond').over(Window.orderBy('Date')))

+----------+---+----+---+--------+
|      Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08|  2|   0|  0|       0|
|2022-01-09|  4|   1|  4|       1|
|2022-01-10|  6|   1| 10|       2|
|2022-01-11|  8|   0|  0|       2|
|2022-01-12|  2|   1|  8|       3|
|2022-01-13|  5|   1|  7|       4|
|2022-01-14|  7|   0|  0|       4|
|2022-01-15|  9|   0|  0|       4|
|2022-01-16|  9|   0|  0|       4|
|2022-01-17|  9|   0|  0|       4|
+----------+---+----+---+--------+

Finally, for each partition where the cond_sum is greater than 1, use the max sum for that partition:

df_3.withColumn(
    'sum',
    when(df_3.cond_sum > 1, max('sum').over(Window.partitionBy('cond_sum'))).otherwise(0)
).show()

+----------+---+----+---+--------+
|      Date|Val|Cond|sum|cond_sum|
+----------+---+----+---+--------+
|2022-01-08|  2|   0|  0|       0|
|2022-01-09|  4|   1|  0|       1|
|2022-01-10|  6|   1| 10|       2|
|2022-01-11|  8|   0| 10|       2|
|2022-01-12|  2|   1|  8|       3|
|2022-01-13|  5|   1|  7|       4|
|2022-01-14|  7|   0|  7|       4|
|2022-01-15|  9|   0|  7|       4|
|2022-01-16|  9|   0|  7|       4|
|2022-01-17|  9|   0|  7|       4|
+----------+---+----+---+--------+
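
As a rough alternative sketch (not from the steps above, and assuming Spark 3.1+ for F.aggregate), the same result can also be computed with a single ordered window, by collecting the Val of every preceding Cond == 1 row and summing the two most recent values:

# Sketch of a single-window alternative (assumes Spark 3.1+ for F.aggregate)
from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.orderBy("Date").rowsBetween(Window.unboundedPreceding, -1)

# Val for preceding Cond == 1 rows only; collect_list drops the nulls produced by when()
prev_vals = F.collect_list(F.when(F.col("Cond") == 1, F.col("Val"))).over(w)

# The two most recent of those values (reverse the frame-ordered list, take the first two)
last_two = F.slice(F.reverse(prev_vals), 1, 2)

df_alt = df_1.withColumn(
    "Sum",
    F.when(
        F.size(last_two) >= 2,
        F.aggregate(last_two, F.lit(0), lambda acc, x: acc + x)
    ).otherwise(0)
)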