
First things first, I hope I am formatting my question correctly.
I have this DataFrame:

df = sc.parallelize([
('1112', 1, 0, 1, '2018-05-01'),
('1111', 1, 1, 1, '2018-05-01'),
('1111', 1, 3, 2, '2018-05-04'),
('1111', 1, 1, 2, '2018-05-05'),
('1111', 1, 1, 2, '2018-05-06'),
]).toDF(["customer_id", "buy_count", "date_difference", "expected_answer", "date"]).cache()

df.show()
+-----------+---------+---------------+---------------+----------+
|customer_id|buy_count|date_difference|expected_answer|      date|
+-----------+---------+---------------+---------------+----------+
|       1111|        1|              1|              1|2018-05-01|
|       1111|        1|              3|              2|2018-05-04|
|       1111|        1|              1|              2|2018-05-05|
|       1111|        1|              1|              2|2018-05-06|
|       1112|        1|              0|              1|2018-05-01|
+-----------+---------+---------------+---------------+----------+

I want to create the "expected_answer" column:

If a customer hasn't bought for 3 or more days (date_difference >= 3), I want to increase his buy_count by 1. Every purchase after that should carry the new buy_count, unless he again goes 3 or more days without buying, in which case buy_count increases once more.

Here is my code and how far I have gotten with it. The problem seems to be that Spark does not actually update the value in place but creates a new column, so the lag only ever sees the original buy_count. Is there a way to get past this? I also tried it in Hive, with exactly the same results.

from pyspark.sql.window import Window
import pyspark.sql.functions as func
from pyspark.sql.functions import when

windowSpec = func.lag(df['buy_count']).over(
    Window.partitionBy(df['customer_id']).orderBy(df['date'].asc())
)

df.withColumn(
    'buy_count',
    when(df['date_difference'] >= 3, windowSpec + 1)
    .when(windowSpec.isNull(), 1)
    .otherwise(windowSpec)
).show()

+-----------+---------+---------------+---------------+----------+
|customer_id|buy_count|date_difference|expected_answer|      date|
+-----------+---------+---------------+---------------+----------+
|       1112|        1|              0|              1|2018-05-01|
|       1111|        1|              1|              1|2018-05-01|
|       1111|        2|              3|              2|2018-05-04|
|       1111|        1|              1|              2|2018-05-05|
|       1111|        1|              1|              2|2018-05-06|
+-----------+---------+---------------+---------------+----------+

How can I get the expected result? Thanks in advance.

  • Possible duplicate of [Spark - Window with recursion? - Conditionally propagating values across rows](https://stackoverflow.com/questions/45277487/spark-window-with-recursion-conditionally-propagating-values-across-rows) – zero323 Oct 23 '18 at 09:45
  • You can also check [Spark SQL window function with complex condition](https://stackoverflow.com/q/42448564/6910411) which shows the same pattern with dates, but in Scala. – zero323 Oct 23 '18 at 09:51

1 Answer


Figured it out at last. Thanks everyone for pointing out similar cases.

I was under the impression that SUM() over a partition would always sum over the whole partition, not just everything up to and including the current row. With an ORDER BY in the window it becomes a running sum, and luckily I was able to solve my problem with a very simple SQL query:

df.createOrReplaceTempView("df")  # register the DataFrame so SQL can refer to it as "df"
qry = """SELECT SUM(CASE WHEN date_difference >= 3 THEN 1 ELSE 0 END)
                OVER (PARTITION BY customer_id ORDER BY date) AS answer
         FROM df"""

sqlContext.sql(qry).show()
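
For anyone who prefers the DataFrame API, here is a rough equivalent of the same running-sum idea (a sketch, assuming the column names from the question; adding the original buy_count to the running gap count reproduces the expected_answer column):

from pyspark.sql.window import Window
import pyspark.sql.functions as F

w = Window.partitionBy('customer_id').orderBy('date')

# running count of purchase gaps (date_difference >= 3), added to the initial buy_count
df.withColumn(
    'answer',
    F.col('buy_count') + F.sum(
        F.when(F.col('date_difference') >= 3, 1).otherwise(0)
    ).over(w)
).show()

Because the window has an ORDER BY, the default frame runs from the start of the partition to the current row, which is exactly what turns SUM into a running total.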