Whenever a particular store changes its brand, I need to populate the month (MTH_ID) when the change happened. This should be applied to every store.

+------+-----------+---------------+-------------+-------------+
|MTH_ID| store_id  |     brand     |    brndSales|   TotalSales|
+------+-----------+---------------+-------------+-------------+
|201801|      10941|            115|  80890.44900| 135799.66400|
|201712|      10941|            123| 517440.74500| 975893.79000|
|201711|      10941|             99| 371501.92100| 574223.52300|
|201710|      10941|            115| 552435.57800| 746912.06700|
|201709|      10941|            115|1523492.60700|1871480.06800|
|201708|      10941|            115|1027698.93600|1236544.50900|
|201707|      10941|             33|1469219.86900|1622949.53000|
+------+-----------+---------------+-------------+-------------+

The output should look like the following:

+------+-----------+---------------+-------------+-------------+----------+
|MTH_ID| store_id  |     brand     |    brndSales|   TotalSales|switchdate|
+------+-----------+---------------+-------------+-------------+----------+
|201801|      10941|            115|  80890.44900| 135799.66400|    201712|
|201712|      10941|            123| 517440.74500| 975893.79000|    201711|
|201711|      10941|             99| 371501.92100| 574223.52300|    201710|
|201710|      10941|            115| 552435.57800| 746912.06700|    201707|
|201709|      10941|            115|1523492.60700|1871480.06800|    201707|
|201708|      10941|            115|1027698.93600|1236544.50900|    201707|
|201707|      10941|             33|1469219.86900|1622949.53000|    201706|
+------+-----------+---------------+-------------+-------------+----------+

I thought of applying lag, but we need to check whether the brand column changed. If there is no change in brand, we have to populate the month when it last changed.

Input data

val data = Seq(
  (201801, 10941, 115,   80890.44900,  135799.66400),
  (201712, 10941, 123,  517440.74500,  975893.79000),
  (201711, 10941,  99,  371501.92100,  574223.52300),
  (201710, 10941, 115,  552435.57800,  746912.06700),
  (201709, 10941, 115, 1523492.60700, 1871480.06800),
  (201708, 10941, 115, 1027698.93600, 1236544.50900),
  (201707, 10941,  33, 1469219.86900, 1622949.53000)
).toDF("MTH_ID", "store_id", "brand", "brndSales", "TotalSales")

Output from the answer's code:

+------+--------+-----+-----------+-----------+---------------+---+----------+
|MTH_ID|store_id|brand|  brndSales| TotalSales|prev_brand_flag|grp|switchdate|
+------+--------+-----+-----------+-----------+---------------+---+----------+
|201801|   10941|  115|  80890.449| 135799.664|              1|  5|    201801|
|201712|   10941|  123| 517440.745|  975893.79|              1|  4|    201712|
|201711|   10941|   99| 371501.921| 574223.523|              1|  3|    201711|
|201710|   10941|  115| 552435.578| 746912.067|              0|  2|    201708|
|201709|   10941|  115|1523492.607|1871480.068|              0|  2|    201708|
|201708|   10941|  115|1027698.936|1236544.509|              1|  2|    201708|
|201707|   10941|   33|1469219.869| 1622949.53|              1|  1|    201707|
+------+--------+-----+-----------+-----------+---------------+---+----------+

Are there any available functions that can serve this purpose?

loneStar
  • Possible duplicate of [Spark SQL window function with complex condition](https://stackoverflow.com/questions/42448564/spark-sql-window-function-with-complex-condition) – user10938362 May 21 '19 at 18:17
  • @user10938362 The logic there is far simpler. I think this question is not solved by lag; we might need UDFs – loneStar May 21 '19 at 19:09
  • @user10938362 Can you remove the duplicate tag? That question doesn't answer my question – loneStar May 21 '19 at 19:42
  • why should the switchdate be from the *previous* row and not from where the value changed? – Vamsi Prabhala May 21 '19 at 20:47
  • @VamsiPrabhala Yes, that is correct, but how can I apply that? I mean, do window functions help? – loneStar May 21 '19 at 20:50

1 Answer

PySpark solution.

Use lag with a running sum: flag whether the value changed from the previous row and, if so, increment a counter to form groups. Once grouping is done, it is a matter of getting the min date per group.
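
Since the sample data in the question is built with Scala, here is a PySpark equivalent to run the code below against (a sketch; the column names are taken from the toDF call, and an active spark session is assumed):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# same rows as the Scala Seq in the question
data = [(201801, 10941, 115, 80890.449, 135799.664),
        (201712, 10941, 123, 517440.745, 975893.79),
        (201711, 10941, 99, 371501.921, 574223.523),
        (201710, 10941, 115, 552435.578, 746912.067),
        (201709, 10941, 115, 1523492.607, 1871480.068),
        (201708, 10941, 115, 1027698.936, 1236544.509),
        (201707, 10941, 33, 1469219.869, 1622949.53)]
df = spark.createDataFrame(data, ["MTH_ID", "store_id", "brand", "brndSales", "TotalSales"])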

from pyspark.sql import Window
from pyspark.sql.functions import lag, min, sum, when

# flag brand changes per store, then running-sum the flags to form groups
w1 = Window.partitionBy(df.store_id).orderBy(df.MTH_ID)
df = df.withColumn('prev_brand_flag', when(lag(df.brand).over(w1) == df.brand, 0).otherwise(1))
df = df.withColumn('grp', sum(df.prev_brand_flag).over(w1))
# the switch date is the earliest month within each group
w2 = Window.partitionBy(df.store_id, df.grp)
res = df.withColumn('switchdate', min(df.MTH_ID).over(w2))
res.show()

Looking at the results of the intermediate DataFrames (the prev_brand_flag and grp columns shown in the question) will give you an idea of how the logic works.

Vamsi Prabhala
  • Wow, the way you create the sum to make a new window. Super. I am getting good output, but in my scenario switch_date is the previous month, not the current month – loneStar May 22 '19 at 00:36
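
If, as in the question's expected output, switchdate should be the month before the first month of each group, one possible adjustment is to shift the computed month back by one. This is a sketch, assuming Spark 2.2+ for to_date with a format string:

from pyspark.sql.functions import add_months, date_format, to_date

# shift each group's first month back by one month, e.g. 201708 -> 201707
res = res.withColumn(
    'switchdate',
    date_format(add_months(to_date(res.switchdate.cast('string'), 'yyyyMM'), -1), 'yyyyMM')
)
res.show()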