
I have the code below, where I need to reuse a flag from the previous day, so I am running a loop. I can't use a simple lag/offset here because today's flag can only be computed once I know the previous day's flag. This loop runs 1000 times, and afterwards any operation on 'data_wt_flag1' takes too long and eventually fails with a "spark driver stopped unexpectedly" error. I believe it is due to a memory issue. Is there a better way to write this logic? As I mentioned, I can't use a plain offset.

from pyspark.sql.functions import col, when
from pyspark.sql.types import StructType

# Distinct dates, processed in chronological order
DateList = data.select("Date").distinct().orderBy("Date").rdd.flatMap(lambda x: x).collect()
Flag_list = []

# Empty accumulator DataFrame; each day's rows are unioned onto it
data_wt_flag1 = spark.createDataFrame(data=[], schema=StructType([]))

for daily_date in DateList:
    print(daily_date)
    Temp_data_daily = data.filter(col("Date").isin(daily_date))
    # lag_1 = 1 if this identifier had condition_1 = 1 on the previous day
    Temp_data_daily = Temp_data_daily.withColumn("lag_1", when(col("identifier").isin(Flag_list), 1).otherwise(0))
    # condition_1 = 1 if col_1 == 1 and (col_2 == 1 or yesterday's condition_1 == 1)
    Temp_data_daily = Temp_data_daily.withColumn("condition_1", when((col("col_1") == 1) & ((col("col_2") == 1) | (col("lag_1") == 1)), 1).otherwise(0))
    # Collect today's flagged identifiers for use in the next iteration
    Flag_list = Temp_data_daily.filter(col("condition_1") == 1).select("identifier").distinct().rdd.flatMap(lambda x: x).collect()
    data_wt_flag1 = data_wt_flag1.unionByName(Temp_data_daily, allowMissingColumns=True)

Logic of the code in words:

If (col_1==1 and (col_2==1 or yesterday(condition_1)==1)) then today(condition_1)=1 otherwise 0.

For the first date in the data, yesterday(condition_1) is 0 for all identifiers, so I pass an empty Flag_list initially. The list is then rebuilt in every iteration and used to flag identifiers in the next iteration, which is how lag_condition_1 is created.
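
Written as a plain Python helper, just to restate the per-day rule (the helper name is made up for illustration; it is not part of the actual job):

def condition_1_today(col_1, col_2, condition_1_yesterday):
    # 1 if col_1 == 1 and (col_2 == 1 or yesterday's condition_1 == 1), else 0
    if col_1 == 1 and (col_2 == 1 or condition_1_yesterday == 1):
        return 1
    return 0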

Below is the sample data; I have only shown the required columns.

Identifier  Date        col_1  col_2
ABC         2023-08-20  1      1
GHI         2023-08-20  0      0
ABC         2023-08-21  1      0
GHI         2023-08-21  1      0
ABC         2023-08-22  1      0
GHI         2023-08-22  1      0
ABC         2023-08-23  1      0
GHI         2023-08-23  0      0

The table below shows the desired output.

Identifier  Date        col_1  col_2  lag_condition_1  condition_1
ABC         2023-08-20  1      1      0                1
GHI         2023-08-20  0      0      0                0
ABC         2023-08-21  1      0      1                1
GHI         2023-08-21  1      0      0                0
ABC         2023-08-22  1      0      1                1
GHI         2023-08-22  1      0      0                0
ABC         2023-08-23  1      0      1                1
GHI         2023-08-23  0      0      0                0

Here, for the first date, all lag_condition_1 values are zero because I pass an empty list. Then for the second date, lag_condition_1 = 1 for ABC, since it had condition_1 = 1 on the previous date.
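
For reproducing the example, the sample input above can be built as a PySpark DataFrame like this (just a convenience sketch; the real data has more columns than shown):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

rows = [
    ("ABC", "2023-08-20", 1, 1), ("GHI", "2023-08-20", 0, 0),
    ("ABC", "2023-08-21", 1, 0), ("GHI", "2023-08-21", 1, 0),
    ("ABC", "2023-08-22", 1, 0), ("GHI", "2023-08-22", 1, 0),
    ("ABC", "2023-08-23", 1, 0), ("GHI", "2023-08-23", 0, 0),
]
data = spark.createDataFrame(rows, ["Identifier", "Date", "col_1", "col_2"])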

ASD
  • You can try to utilize Spark's built-in capabilities to handle this type of problem more efficiently. In your case, you're essentially trying to propagate flags based on some conditions across multiple days. You can leverage window functions for this purpose. – Robby star Aug 16 '23 at 18:34
  • Thanks. But I am not sure what you mean. Do you mean something like this? Temp_data_daily=Temp_data_daily.withColumn('lag_1',lag("condition_1",offset=1).over(Window.partitionBy('identifier').orderBy('identifier','Date'))) But this is not the same as what I want. Can you let me know how I can do it? – ASD Aug 16 '23 at 21:36
  • You call `collect()` in a for loop, which can lead to performance and memory issues on the driver where the data is ultimately **collected**, i.e. received from worker nodes. I agree with the commenter above, there's probably a way to write your code in a way that will allow Spark to process your data more efficiently and in a more distributed fashion. Could you perhaps post samples of your input data and the desired state of the output, so we have a better idea of what you need and possibly hint you towards the right direction? – aidar_ms Aug 16 '23 at 22:53
  • I have edited the question with sample input and output. – ASD Aug 17 '23 at 08:19

1 Answer


Yeah, a window function is your go-to:

import org.apache.spark.sql.{Row, functions => F}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.{DateType, StructType, StructField, StringType, IntegerType}

// Sample data
val data = Seq(
    ("ABC", "2023-08-20", 1),
    ("DEF", "2023-08-20", 0),
    ("GHI", "2023-08-20", 0),
    ("MNO", "2023-08-20", 1),
    ("XYZ", "2023-08-20", 0),
    ("ABC", "2023-08-21", 0),
    ("DEF", "2023-08-21", 1),
    ("GHI", "2023-08-21", 0),
    ("MNO", "2023-08-21", 0),
    ("XYZ", "2023-08-21", 0),
)

// Optional: define the schema
val schema = StructType(Seq(
  StructField("Identifier", StringType, nullable = false),
  StructField("Date", DateType, nullable = false),
  StructField("col_1", IntegerType, nullable = false)
))

// Create RDD of Rows
val rowsRDD = spark.sparkContext.parallelize(data.map { case (id, date, col1) => Row(id, java.sql.Date.valueOf(date), col1) })

// Create DataFrame
var df = spark.createDataFrame(rowsRDD, schema)

// Define operations on your data
val windowSpec = Window.partitionBy("Identifier").orderBy("Date")

df = df.withColumn("lag_1", F.lag("col_1", 1, 0).over(windowSpec))
df = df.withColumn(
  "condition_1",
  F.when(F.col("col_1") === 1 || F.col("lag_1") === 1, 1).otherwise(0)
)

df = df.orderBy("Date", "Identifier")

// Perform an action only once (at the end of your program) to collect results
val result = df.collect()  // Or do df.show() to print out the dataframe

I'm oversimplifying, but as a rule of thumb: when you're calling methods of a Spark DataFrame (e.g. .orderBy()) or using functions from the Spark library (e.g. window functions like lag), you're utilising Spark as a distributed data processing tool, as it should be used. After all the data operations are done, you collect results with collect() or show() to have the final result sent to your driver. Calling collect() in a loop to fetch intermediate results into an array (Flag_list in your snippet) and then iterating over it on your driver for additional processing is usually a sign of bad practice.
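
Since the question's code is in PySpark (see the comment below), a rough Python translation of the Scala snippet above might look like this; treat it as an untested sketch of the same window-function idea:

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, when, lag

spark = SparkSession.builder.getOrCreate()

# Same sample data as the Scala version; dates kept as strings for brevity
rows = [
    ("ABC", "2023-08-20", 1), ("DEF", "2023-08-20", 0), ("GHI", "2023-08-20", 0),
    ("MNO", "2023-08-20", 1), ("XYZ", "2023-08-20", 0),
    ("ABC", "2023-08-21", 0), ("DEF", "2023-08-21", 1), ("GHI", "2023-08-21", 0),
    ("MNO", "2023-08-21", 0), ("XYZ", "2023-08-21", 0),
]
df = spark.createDataFrame(rows, ["Identifier", "Date", "col_1"])

window_spec = Window.partitionBy("Identifier").orderBy("Date")

# lag_1 = previous row's col_1 per identifier, defaulting to 0 on the first date
df = df.withColumn("lag_1", lag("col_1", 1, 0).over(window_spec))
df = df.withColumn("condition_1", when((col("col_1") == 1) | (col("lag_1") == 1), 1).otherwise(0))

df = df.orderBy("Date", "Identifier")

# Perform an action only once, at the end, to collect results
df.show()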

aidar_ms
  • Oh damn it, I thought your code was in Scala. Sorry, long day at work, but I don't think it'd be hard to translate over to Python; most of it is Spark API, which is roughly similar in both languages – aidar_ms Aug 17 '23 at 22:11
  • Here, lag_1 is not a lag of col_1 but a lag of condition_1. Once I calculate condition_1 for the first date, I use it in the next iteration to create lag_1. – ASD Aug 18 '23 at 12:38
  • I have tried to explain the logic in detail and added one more condition (col_2) to make it clearer. If you follow the ABC identifier you will get the idea. Is it possible without the loop and collect()? – ASD Aug 19 '23 at 09:00
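
One possible window-only formulation of the rule from the question (a rough, untested sketch, different from the answer's code above, and assuming the question's data DataFrame with the columns shown in the sample): group consecutive col_1 == 1 days per identifier into "runs" (a new run starts whenever col_1 drops to 0) and, within a run, set condition_1 to 1 from the first day where col_2 == 1 onwards; lag_condition_1 is then an ordinary lag of that column.

from pyspark.sql import Window
from pyspark.sql.functions import col, when, lag
from pyspark.sql.functions import sum as sum_, max as max_

w_id = Window.partitionBy("Identifier").orderBy("Date")

# A new "run" starts every time col_1 drops to 0; condition_1 resets there
df = data.withColumn("run_id", sum_(when(col("col_1") == 0, 1).otherwise(0)).over(w_id))

w_run = Window.partitionBy("Identifier", "run_id").orderBy("Date")

# Within a run, flag rows from the first day where col_1 == 1 and col_2 == 1 onwards
seen_start = max_(when((col("col_1") == 1) & (col("col_2") == 1), 1).otherwise(0)).over(w_run)
df = df.withColumn("condition_1", when((col("col_1") == 1) & (seen_start == 1), 1).otherwise(0))

# Previous day's condition_1 per identifier (0 for the first date)
df = df.withColumn("lag_condition_1", lag("condition_1", 1, 0).over(w_id))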