
I have the following data:

client_id,transaction_id,start,end,amount
1,1,2018-12-09,2018-12-11,1000
1,2,2018-12-19,2018-12-21,2000
1,3,2018-12-19,2018-12-31,3000
2,4,2018-11-09,2018-12-20,4000
2,5,2018-12-19,2018-12-21,5000
2,6,2018-12-22,2018-12-31,6000

Using PySpark I am trying to add a column that shows, for each row, the number of the same client's transactions that were already closed (i.e. whose end date lies before the current row's start date). I was able to do this in Pandas with the fairly simple code shown below:

import pandas as pd
df = pd.read_csv('transactions.csv')
df['closed_transactions'] = df.apply(
    lambda row: len(df[(df['end'] < row['start']) & (df['client_id'] == row['client_id'])]),
    axis=1
)

Resulting in the dataframe

client_id   transaction_id  start   end amount  closed_transactions
0   1   1   2018-12-09  2018-12-11  1000    0
1   1   2   2018-12-19  2018-12-21  2000    1
2   1   3   2018-12-19  2018-12-31  3000    1
3   2   4   2018-11-09  2018-12-20  4000    0
4   2   5   2018-12-19  2018-12-21  5000    0
5   2   6   2018-12-22  2018-12-31  6000    2

However, I struggle to achieve the same thing in PySpark. Adding a simple counter per group with a Window function works, and so does the cumulative sum, but I am unable to get the number of closed transactions given the data of the current row.

from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext, Window
import pyspark.sql.functions as psf

config = SparkConf().setMaster('local')
spark = SparkContext.getOrCreate(conf=config)
sqlContext = SQLContext(spark)

spark_df = sqlContext.read.csv('transactions.csv', header=True)
window = Window.partitionBy('client_id').orderBy('start').rowsBetween(Window.unboundedPreceding, 0)

# This is the part that does not work: spark_df cannot be referenced from inside a UDF
@psf.udf('string')
def get_number_of_transactions(curr_date):
    return spark_df[spark_df['end'] < curr_date].count()

spark_df \
    .withColumn('number_of_past_transactions',
                psf.row_number().over(window) - 1) \
    .withColumn('total_amount', psf.sum(psf.col('amount')).over(window)) \
    .withColumn('closed_transactions',
                get_number_of_transactions(psf.col('end'))) \
    .show()

The workaround I have now is to convert the Spark dataframe to Pandas and broadcast it so I can use it in the UDF, but I was hoping there would be a more elegant solution to this problem.
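
For reference, this is roughly what that workaround looks like. It is only a sketch, assuming the whole table fits in driver memory; the broadcast variable and UDF names are mine, and note that `spark` in the snippet above is actually a SparkContext, so `spark.broadcast` is available:

import pandas as pd

# Collect the full table as a Pandas dataframe and broadcast it to the workers
pdf_broadcast = spark.broadcast(spark_df.toPandas())

@psf.udf('int')
def closed_transactions(client_id, start):
    # Count rows of the broadcast table for the same client that ended before this start date
    pdf = pdf_broadcast.value
    return int(((pdf['client_id'] == client_id) & (pdf['end'] < start)).sum())

spark_df.withColumn('closed_transactions',
                    closed_transactions(psf.col('client_id'), psf.col('start'))).show()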

Any help is much appreciated!

JQadrad
  • Joining the dataframe with self on `client ID` and then adding a flag for `df['end']>df['start']` should work with grouping thereafter on `start`. – Rahul Chawla Dec 19 '18 at 13:48

1 Answer


As I mentioned in my comment, you can join the dataframe with itself on client_id and add a boolean flag that is 1 whenever the joined copy's end date lies before the current row's start date. Grouping per transaction and summing this flag then gives the number of closed transactions.

from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext, Window
import pyspark.sql.functions as psf

config = SparkConf().setMaster('local')
spark = SparkContext.getOrCreate(conf=config)
sqlContext = SQLContext(spark)

spark_df = sqlContext.read.csv('transactions.csv', header=True)

# Renaming columns (all except client_id) for the self join
df2 = spark_df
for c in spark_df.columns:
    df2 = df2 if c == 'client_id' else df2.withColumnRenamed(c, 'x_{cl}'.format(cl=c))

# Joining with self on client_id
new_df = spark_df.join(df2, 'client_id')

# Flagging joined rows whose end date lies before the current row's start date,
# then summing the flag per transaction
new_df = new_df \
    .withColumn('valid_transaction',
                psf.when(psf.col('x_end') < psf.col('start'), 1).otherwise(0)) \
    .groupBy(['client_id', 'transaction_id']) \
    .agg(psf.sum('valid_transaction').alias('closed_transactions'))
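
If you also want the count next to the original columns (as in the Pandas output), one option is to join this aggregate back onto the original dataframe. A short sketch, assuming the corrected column names used above:

# Attach the per-transaction counts back to the original rows
result = spark_df \
    .join(new_df, on=['client_id', 'transaction_id'], how='left') \
    .orderBy('client_id', 'transaction_id')
result.show()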
Rahul Chawla
  • Super. Seems to work well after a few small modifications. Is it safe to use `df2 = spark_df`? Should we use a deep_copy() to not reference the wrong object? – JQadrad Dec 19 '18 at 15:43
  • I think this method will work. More details here https://stackoverflow.com/a/52289158/6775799. Don't forget to accept the answer if this was what you were looking for. – Rahul Chawla Dec 19 '18 at 16:07