First things first, I hope I am formatting my question correctly.
I have this dataframe:
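from pyspark.sql import SparkSession

# In the pyspark shell `spark` and `sc` already exist; this is for a standalone script
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

df = sc.parallelize([
    ('1112', 1, 0, 1, '2018-05-01'),
    ('1111', 1, 1, 1, '2018-05-01'),
    ('1111', 1, 3, 2, '2018-05-04'),
    ('1111', 1, 1, 2, '2018-05-05'),
    ('1111', 1, 1, 2, '2018-05-06'),
]).toDF(["customer_id", "buy_count", "date_difference", "expected_answer", "date"]).cache()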
df.show()
+-----------+---------+---------------+---------------+----------+
|customer_id|buy_count|date_difference|expected_answer|      date|
+-----------+---------+---------------+---------------+----------+
|       1111|        1|              1|              1|2018-05-01|
|       1111|        1|              3|              2|2018-05-04|
|       1111|        1|              1|              2|2018-05-05|
|       1111|        1|              1|              2|2018-05-06|
|       1112|        1|              0|              1|2018-05-01|
+-----------+---------+---------------+---------------+----------+
I want to create the "expected_answer" column:
If a customer hasn't bought anything for 3 or more days (date_difference >= 3), I want to increase his buy_count by 1. Every purchase after that should carry the new buy_count until he again goes 3 or more days without buying, at which point buy_count increases again.
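To make the rule concrete, here is a plain-Python sketch (no Spark) applied to the sample rows. It assumes the threshold is 3 days (`GAP_DAYS`, a hypothetical name), that each customer's count starts from the buy_count of his first purchase, and that a date_difference >= 3 on a first purchase would also count as a gap (the sample data does not exercise that case). The function names are illustrative only. The point it shows: carrying the count forward row by row gives the same result as "starting buy_count + running count of gaps", i.e. a cumulative sum, which needs no access to previously computed values.

```python
from itertools import accumulate, groupby
from operator import itemgetter

# Sample rows copied from the question:
# (customer_id, buy_count, date_difference, expected_answer, date)
rows = [
    ('1112', 1, 0, 1, '2018-05-01'),
    ('1111', 1, 1, 1, '2018-05-01'),
    ('1111', 1, 3, 2, '2018-05-04'),
    ('1111', 1, 1, 2, '2018-05-05'),
    ('1111', 1, 1, 2, '2018-05-06'),
]

GAP_DAYS = 3  # assumed threshold: a gap of 3 or more days bumps the count


def carried_buy_count(rows):
    """Walk each customer's purchases in date order, carrying the count forward."""
    out = {}
    ordered = sorted(rows, key=itemgetter(0, 4))  # by customer_id, then date
    for _, purchases in groupby(ordered, key=itemgetter(0)):
        current = None
        for customer_id, buy_count, diff, _expected, date in purchases:
            if current is None:
                current = buy_count  # first purchase starts from its buy_count
            if diff >= GAP_DAYS:
                current += 1  # long gap: increment, and keep the new value for later rows
            out[(customer_id, date)] = current
    return out


def running_gap_count(rows):
    """Same result without carried state: first buy_count + cumulative sum of gap flags."""
    out = {}
    ordered = sorted(rows, key=itemgetter(0, 4))
    for _, purchases in groupby(ordered, key=itemgetter(0)):
        purchases = list(purchases)
        base = purchases[0][1]
        flags = [int(p[2] >= GAP_DAYS) for p in purchases]  # 1 where a long gap occurred
        for p, gaps in zip(purchases, accumulate(flags)):  # accumulate = running sum
            out[(p[0], p[4])] = base + gaps
    return out


expected = {(r[0], r[4]): r[3] for r in rows}
print(carried_buy_count(rows) == expected, running_gap_count(rows) == expected)  # True True
```

Because the second version only needs a running sum over each customer's rows in date order, it maps directly onto a SQL/Spark window aggregate, whereas the first version depends on the value just computed for the previous row.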
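Here is my code and how far I have gotten with it. The problem seems to be that Spark does not update buy_count row by row: lag() reads the original buy_count column, not the values being written by withColumn, so the increment on one row is never carried into the next. Is there a way to get past this? I also tried it in Hive and got exactly the same results.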
from pyspark.sql.window import Window
import pyspark.sql.functions as func
from pyspark.sql.functions import when

# Each customer's purchases, in date order
w = Window.partitionBy(df['customer_id']).orderBy(df['date'].asc())

# buy_count of the previous purchase (null for a customer's first purchase)
prev_buy_count = func.lag(df['buy_count']).over(w)

df.withColumn(
    'buy_count',
    when(df['date_difference'] >= 3, prev_buy_count + 1)
    .when(prev_buy_count.isNull(), 1)
    .otherwise(prev_buy_count),
).show()
+-----------+---------+---------------+---------------+----------+
|customer_id|buy_count|date_difference|expected_answer|      date|
+-----------+---------+---------------+---------------+----------+
|       1112|        1|              0|              1|2018-05-01|
|       1111|        1|              1|              1|2018-05-01|
|       1111|        2|              3|              2|2018-05-04|
|       1111|        1|              1|              2|2018-05-05|
|       1111|        1|              1|              2|2018-05-06|
+-----------+---------+---------------+---------------+----------+
How can I get the expected result? Thanks in advance.