It looks like you don't have a partition column for your window, and the events don't have the same number of records. Given that, the solution that comes to my mind is to use the position of each event start to retrieve the respective value.
Sorting by timestamp, we first extract the position of each line:
from pyspark.sql import Window
from pyspark.sql.functions import col, rank, collect_list, expr

df = (
    spark.createDataFrame(
        [
            {'timestamp': '2021-02-02 01:03:55', 'col1': 's1'},
            {'timestamp': '2021-02-02 01:04:16.952854', 'col1': 's1', 'col2': 'other_ind'},
            {'timestamp': '2021-02-02 01:04:32.398155', 'col1': 's1'},
            {'timestamp': '2021-02-02 01:04:53.793089', 'col1': 's1', 'col2': 'event_start_ind', 'col3': 'event_1_value'},
            {'timestamp': '2021-02-02 01:05:10.936913', 'col1': 's1'},
            {'timestamp': '2021-02-02 01:05:36', 'col1': 's1', 'col2': 'other_ind'},
            {'timestamp': '2021-02-02 01:05:42', 'col1': 's1'},
            {'timestamp': '2021-02-02 01:05:43', 'col1': 's1'},
            {'timestamp': '2021-02-02 01:05:44', 'col1': 's1', 'col2': 'event_start_ind', 'col3': 'event_2_value'},
            {'timestamp': '2021-02-02 01:05:46.623198', 'col1': 's1'},
            {'timestamp': '2021-02-02 01:06:50', 'col1': 's1'},
            {'timestamp': '2021-02-02 01:07:19.607685', 'col1': 's1'}
        ]
    )
    .withColumn('timestamp', col('timestamp').cast('timestamp'))
    # no partition column available, so the window spans the whole data frame
    .withColumn("line", rank().over(Window.orderBy("timestamp")))
)
df.show(truncate=False)
+----+--------------------------+---------------+-------------+----+
|col1|timestamp |col2 |col3 |line|
+----+--------------------------+---------------+-------------+----+
|s1 |2021-02-02 01:03:55 |null |null |1 |
|s1 |2021-02-02 01:04:16.952854|other_ind |null |2 |
|s1 |2021-02-02 01:04:32.398155|null |null |3 |
|s1 |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|4 |
|s1 |2021-02-02 01:05:10.936913|null |null |5 |
|s1 |2021-02-02 01:05:36 |other_ind |null |6 |
|s1 |2021-02-02 01:05:42 |null |null |7 |
|s1 |2021-02-02 01:05:43 |null |null |8 |
|s1 |2021-02-02 01:05:44 |event_start_ind|event_2_value|9 |
|s1 |2021-02-02 01:05:46.623198|null |null |10 |
|s1 |2021-02-02 01:06:50 |null |null |11 |
|s1 |2021-02-02 01:07:19.607685|null |null |12 |
+----+--------------------------+---------------+-------------+----+
After that, we identify each event start:
df_event_start = (
    df.filter(col("col2") == 'event_start_ind')
    .select(
        col("line").alias("event_start_line"),
        col("col3").alias("event_value")
    )
)
df_event_start.show()
+----------------+-------------+
|event_start_line| event_value|
+----------------+-------------+
| 4|event_1_value|
| 9|event_2_value|
+----------------+-------------+
Then we use the event start positions to find, for each line, the next valid event start (the first event start at or after that line):
df_with_event_starts = (
    df.join(
        # single-row data frame holding the list of event start lines (cross join)
        df_event_start.select(collect_list('event_start_line').alias("event_starts"))
    )
    # first event start line that is >= the current line
    .withColumn("next_valid_event", expr("element_at(filter(event_starts, x -> x >= line), 1)"))
)
df_with_event_starts.show(truncate=False)
+----+--------------------------+---------------+-------------+----+------------+----------------+
|col1|timestamp |col2 |col3 |line|event_starts|next_valid_event|
+----+--------------------------+---------------+-------------+----+------------+----------------+
|s1 |2021-02-02 01:03:55 |null |null |1 |[4, 9] |4 |
|s1 |2021-02-02 01:04:16.952854|other_ind |null |2 |[4, 9] |4 |
|s1 |2021-02-02 01:04:32.398155|null |null |3 |[4, 9] |4 |
|s1 |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|4 |[4, 9] |4 |
|s1 |2021-02-02 01:05:10.936913|null |null |5 |[4, 9] |9 |
|s1 |2021-02-02 01:05:36 |other_ind |null |6 |[4, 9] |9 |
|s1 |2021-02-02 01:05:42 |null |null |7 |[4, 9] |9 |
|s1 |2021-02-02 01:05:43 |null |null |8 |[4, 9] |9 |
|s1 |2021-02-02 01:05:44 |event_start_ind|event_2_value|9 |[4, 9] |9 |
|s1 |2021-02-02 01:05:46.623198|null |null |10 |[4, 9] |null |
|s1 |2021-02-02 01:06:50 |null |null |11 |[4, 9] |null |
|s1 |2021-02-02 01:07:19.607685|null |null |12 |[4, 9] |null |
+----+--------------------------+---------------+-------------+----+------------+----------------+
And finally we retrieve the correct value:
(
    df_with_event_starts.join(
        df_event_start,
        col("next_valid_event") == col("event_start_line"),
        how="left"
    )
    .drop("line", "event_starts", "next_valid_event", "event_start_line")
    .show(truncate=False)
)
+----+--------------------------+---------------+-------------+-------------+
|col1|timestamp |col2 |col3 |event_value |
+----+--------------------------+---------------+-------------+-------------+
|s1 |2021-02-02 01:03:55 |null |null |event_1_value|
|s1 |2021-02-02 01:04:16.952854|other_ind |null |event_1_value|
|s1 |2021-02-02 01:04:32.398155|null |null |event_1_value|
|s1 |2021-02-02 01:04:53.793089|event_start_ind|event_1_value|event_1_value|
|s1 |2021-02-02 01:05:10.936913|null |null |event_2_value|
|s1 |2021-02-02 01:05:36 |other_ind |null |event_2_value|
|s1 |2021-02-02 01:05:42 |null |null |event_2_value|
|s1 |2021-02-02 01:05:43 |null |null |event_2_value|
|s1 |2021-02-02 01:05:44 |event_start_ind|event_2_value|event_2_value|
|s1 |2021-02-02 01:05:46.623198|null |null |null |
|s1 |2021-02-02 01:06:50 |null |null |null |
|s1 |2021-02-02 01:07:19.607685|null |null |null |
+----+--------------------------+---------------+-------------+-------------+
This solution will bring you problems when processing large volumes of data, because the window without a partition moves all the data to a single partition.
If you can figure out a key for each event, I advise you to continue with your initial solution using window functions. In that case, you can test the last or first SQL functions (ignoring null values).
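As a minimal sketch of that idea, assuming such a key exists (here I use col1 as a stand-in partition column, which is an assumption about your data), first with ignorenulls=True over a forward-looking frame should reproduce the result above:

from pyspark.sql import Window
from pyspark.sql.functions import col, first

# Sketch only: 'col1' stands in for whatever key actually identifies each group.
w = (
    Window.partitionBy("col1")
    .orderBy("timestamp")
    .rowsBetween(Window.currentRow, Window.unboundedFollowing)
)

# first(..., ignorenulls=True) picks the next non-null col3, i.e. the value of
# the next event_start_ind row at or after the current one.
df_filled = df.withColumn("event_value", first(col("col3"), ignorenulls=True).over(w))
df_filled.show(truncate=False)

With the sample data this gives the same event_value column, since there is only one col1 value; the point of the partition key is to avoid moving everything into a single partition.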
Hopefully, someone will help you with a better solution.
Tip: Making the data frame creation scripts available in the question is helpful.