I have the following data:
client_id,transaction_id,start,end,amount
1,1,2018-12-09,2018-12-11,1000
1,2,2018-12-19,2018-12-21,2000
1,3,2018-12-19,2018-12-31,3000
2,4,2018-11-09,2018-12-20,4000
2,5,2018-12-19,2018-12-21,5000
2,6,2018-12-22,2018-12-31,6000
Using PySpark, I am trying to add a column that shows, for each row, the number of transactions of the same client that have already finished before that row's start date. I was able to do this in Pandas with the fairly simple code shown below:
import pandas as pd
df = pd.read_csv('transactions.csv')
df['closed_transactions'] = df.apply(
    lambda row: len(df[(df['end'] < row['start']) & (df['client_id'] == row['client_id'])]),
    axis=1)
This results in the following dataframe:
   client_id  transaction_id       start         end  amount  closed_transactions
0          1               1  2018-12-09  2018-12-11    1000                    0
1          1               2  2018-12-19  2018-12-21    2000                    1
2          1               3  2018-12-19  2018-12-31    3000                    1
3          2               4  2018-11-09  2018-12-20    4000                    0
4          2               5  2018-12-19  2018-12-21    5000                    0
5          2               6  2018-12-22  2018-12-31    6000                    2
However, I struggle to get the same thing working in PySpark. I can add a simple per-group counter using a Window function, and the cumulative sum works too, but I am unable to get the number of closed transactions given the data of the current row.
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext, Window
import pyspark.sql.functions as psf
config = SparkConf().setMaster('local')
spark = SparkContext.getOrCreate(conf=config)
sqlContext = SQLContext(spark)
spark_df = sqlContext.read.csv('transactions.csv', header=True)
window = Window.partitionBy('client_id').orderBy('start').rowsBetween(Window.unboundedPreceding, 0)
# This is the part that does not work: the UDF cannot reference spark_df itself
@psf.udf('int')
def get_number_of_transactions(curr_date):
    return spark_df[spark_df['end'] < curr_date].count()
spark_df \
.withColumn('number_of_past_transactions',
psf.row_number().over(window) - 1) \
.withColumn('total_amount', psf.sum(psf.col('amount')).over(window)) \
    .withColumn('closed_transactions',
                get_number_of_transactions(psf.col('start'))) \
.show()
The workaround I have now is to convert the Spark dataframe to Pandas and broadcast it so I can use it inside the UDF (roughly as in the sketch below), but I was hoping for a more elegant solution.
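For reference, this is roughly what that workaround looks like. It is only a minimal sketch reusing the spark, spark_df and psf names from above; the broadcast variable and the UDF argument names are my own, and the dates are compared as plain strings, which works here because they are in ISO format.
pdf = spark_df.toPandas()        # collect the whole table to the driver
pdf_broadcast = spark.broadcast(pdf)  # `spark` is the SparkContext created above

@psf.udf('int')
def get_number_of_closed_transactions(client_id, start):
    # Count rows of the same client whose end date lies before this row's start date
    pdf = pdf_broadcast.value
    return int(((pdf['client_id'] == client_id) & (pdf['end'] < start)).sum())

spark_df \
    .withColumn('closed_transactions',
                get_number_of_closed_transactions(psf.col('client_id'),
                                                  psf.col('start'))) \
    .show()
This gives the expected numbers, but it collects the whole dataframe to the driver, which obviously will not scale.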
Any help is much appreciated!