
Given the following input dataframe:

npos = 3

inp = spark.createDataFrame([
    ['1', 23, 0, 2],
    ['1', 45, 1, 2],
    ['1', 89, 1, 3],
    ['1', 95, 2, 2],
    ['1', 95, 0, 4],
    ['2', 20, 2, 2],
    ['2', 40, 1, 4],
  ], schema=["id","elap","pos","lbl"])

From it, a dataframe that looks like this needs to be constructed:

out = spark.createDataFrame([
    ['1', 23, [2,0,0]],
    ['1', 45, [2,2,0]],
    ['1', 89, [2,3,0]],
    ['1', 95, [4,3,2]],
    ['2', 20, [0,0,2]],
    ['2', 40, [0,4,2]],
  ], schema=["id","elap","vec"])

The input dataframe has tens of millions of records.

Some details which can be seen in the example above (by design):

  • npos is the size of the vector to be constructed in the output
  • pos is guaranteed to be in [0,npos)
  • at each time step (elap) there will be at most 1 label for a pos
  • if lbl is not given at a time step, it has to be inferred from the last time it was specified for that pos
  • if lbl was not previously specified, it can be assumed to be 0 (a short sketch of these rules follows this list)
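
To make the fill rules above concrete, here is a minimal pure-Python sketch (the names rows, state and snapshots are illustrative only) that reproduces the expected vec values from the sample rows:

npos = 3
rows = [  # (id, elap, pos, lbl), same rows as `inp` above
    ('1', 23, 0, 2), ('1', 45, 1, 2), ('1', 89, 1, 3),
    ('1', 95, 2, 2), ('1', 95, 0, 4),
    ('2', 20, 2, 2), ('2', 40, 1, 4),
]

state = {}       # id -> current vector of length npos, starting as all zeros
snapshots = {}   # (id, elap) -> vector after applying every label up to that time step
for id_, elap, pos, lbl in sorted(rows, key=lambda r: (r[0], r[1])):
    vec = state.setdefault(id_, [0] * npos)
    vec[pos] = lbl                      # the latest label wins for that position
    snapshots[(id_, elap)] = list(vec)  # unspecified positions keep their last value

for (id_, elap), vec in sorted(snapshots.items()):
    print(id_, elap, vec)   # matches the `out` dataframe above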

1 Answer

You can use some higher-order functions on arrays to achieve that:

  1. add a vec column using the array_repeat function, initializing the value at index pos from lbl
  2. use collect_list to get the cumulative list of those arrays over a window partitioned by id and ordered by elap
  3. aggregate the resulting array of arrays, keeping for each position the most recent non-zero value
from pyspark.sql import Window
import pyspark.sql.functions as F

npos = 3

out = inp.withColumn(
    "vec",
    F.expr(f"transform(array_repeat(0, {npos}), (x, i) -> IF(i=pos, lbl, x))")
).withColumn(
    "vec",
    F.collect_list("vec").over(Window.partitionBy("id").orderBy("elap"))
).withColumn(
    "vec",
    F.expr(f"""aggregate(
                  vec, 
                  array_repeat(0, {npos}),
                  (acc, x) -> transform(acc, (y, i) -> int(IF(x[i]!=0, x[i], y)))
            )""")
).drop("lbl", "pos")

out.show(truncate=False)

#+---+----+---------+
#|id |elap|vec      |
#+---+----+---------+
#|1  |23  |[2, 0, 0]|
#|1  |45  |[2, 2, 0]|
#|1  |89  |[2, 3, 0]|
#|1  |95  |[4, 3, 2]|
#|1  |95  |[4, 3, 2]|
#|2  |20  |[0, 0, 2]|
#|2  |40  |[0, 4, 2]|
#+---+----+---------+
  • Thanks! Rows 4 and 5 are the same. Run a distinct on the entire df at the end or is there a better way? – ironv Jan 18 '22 at 04:18
  • To get another example, you can change `elap` in the second row to 23. Then rows 1 & 2 get duplicated. `distinct()` gets rid of that. Your thoughts? – ironv Jan 18 '22 at 04:20
  • @ironv if you don't care which row to preserve then you could do `inp.dropDuplicates(["id", "elap"])` before doing the above transformations. Otherwise you can use [Window with row_number](https://stackoverflow.com/questions/38687212/spark-dataframe-drop-duplicates-and-keep-first) to eliminate duplicates according to a certain order you define. – blackbishop Jan 18 '22 at 07:26
  • Thx. I went with `.withColumn("rn", F.row_number().over(W.partitionBy("id","elap").orderBy(F.desc("pos")))).filter(F.col("rn")==1).drop("rn","pos","lbl")` after the `collect_list` line. – ironv Jan 18 '22 at 16:11
  • For the part of the code which says `IF(i=pos, lbl, x)` does there need to be `==`? Both seem to work though! – ironv Jan 19 '22 at 15:46
  • Can you pls point me to a good resource (with examples) which will help me learn about `pySpark` `transform()`? Seems really powerful and an excellent tool to have in one's toolkit and know well. – ironv Feb 02 '22 at 17:57
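
For reference, the row_number deduplication mentioned in the comments could be spelled out roughly as follows. This is a sketch only: df_after_collect is a placeholder for the dataframe produced right after the collect_list step, and it keeps, for each (id, elap), the row with the highest pos.

from pyspark.sql import Window as W
import pyspark.sql.functions as F

# Sketch: keep one row per (id, elap), preferring the highest pos,
# then drop the helper columns before the aggregate step.
deduped = (
    df_after_collect  # placeholder for the intermediate dataframe
    .withColumn("rn", F.row_number().over(
        W.partitionBy("id", "elap").orderBy(F.desc("pos"))))
    .filter(F.col("rn") == 1)
    .drop("rn", "pos", "lbl")
)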