
I'm trying to uniquely label consecutive rows with equal values in a PySpark dataframe. In Pandas, one could do this quite simply with:

s = pd.Series([1,1,1,2,2,1,1,3])
s.ne(s.shift()).cumsum()
0    1
1    1
2    1
3    2
4    2
5    3
6    3
7    4
dtype: int64

How could this be done in PySpark? Setup -

import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.appName('pandasToSparkDF').getOrCreate()

mySchema = StructType([StructField("col1", IntegerType(), True)])
df_sp = spark.createDataFrame(s.to_frame(), schema=mySchema)

I've found slightly related questions such as this one, but none of them covers this exact scenario.

I'm thinking a good starting point could be to find the first differences, as in this answer.
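
A minimal sketch of that first-differences idea, assuming rows are ordered by an added monotonically_increasing_id column (the id and first_diff names are just placeholders, not from the linked answer):

from pyspark.sql import functions as f, Window

# order by an added row id, then diff col1 against its lagged value
w = Window.orderBy("id")
diffs = (df_sp.withColumn("id", f.monotonically_increasing_id())
              .withColumn("first_diff", f.col("col1") - f.lag("col1").over(w)))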

yatu

2 Answers


I've come up with a solution. The idea is similar to what is done in Pandas. We start by adding a unique identifier column, over which we'll compute the lagged column (using over is necessary here since lag is a window function).

We then compare the column of interest with the lagged column and take the cumulative sum of the result cast to int:

from pyspark.sql import functions as f
from pyspark.sql import Window

mySchema = StructType([StructField("col1", IntegerType(), True)])
df_sp = spark.createDataFrame(s.to_frame(), schema=mySchema)

# add a row id, lag col1 over it, flag value changes, and take a running sum of the flags
win = Window.orderBy("id")
df_sp = (df_sp.withColumn("id", f.monotonically_increasing_id())
              .withColumn("col1_shift", f.lag("col1", offset=1, default=0).over(win))
              .withColumn("col1_shift_ne", (f.col("col1") != f.col("col1_shift")).cast("int"))
              .withColumn("col1_shift_ne_cumsum", f.sum("col1_shift_ne").over(win))
              .drop('id', 'col1_shift', 'col1_shift_ne'))

df_sp.show()
+----+--------------------+
|col1|col1_shift_ne_cumsum|
+----+--------------------+
|   1|                   1|
|   1|                   1|
|   1|                   1|
|   2|                   2|
|   2|                   2|
|   1|                   3|
|   1|                   3|
|   3|                   4|
+----+--------------------+
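
As a quick sanity check (just a sketch; it assumes the collected rows come back in their original order, which holds here because the single ordered window keeps the data in id order), the result can be compared against the pandas version from the question:

# collect to pandas and compare with s.ne(s.shift()).cumsum()
result = df_sp.toPandas()["col1_shift_ne_cumsum"]
expected = s.ne(s.shift()).cumsum()
assert (result.values == expected.values).all()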
yatu

Another way of solving this would be to use rangeBetween with an unbounded preceding sum after comparing against the lag:

from pyspark.sql import functions as F, Window as W

# w1 orders rows for the lag; w2 builds a running sum over the same ordering
w1 = W.orderBy(F.monotonically_increasing_id())
w2 = W.orderBy(F.monotonically_increasing_id()).rangeBetween(W.unboundedPreceding, 0)

cond = F.col("col1") != F.lag("col1").over(w1)
df_sp.withColumn("col1_shift_ne_cumsum", F.sum(F.when(cond, 1).otherwise(0)).over(w2) + 1).show()

+----+--------------------+
|col1|col1_shift_ne_cumsum|
+----+--------------------+
|   1|                   1|
|   1|                   1|
|   1|                   1|
|   2|                   2|
|   2|                   2|
|   1|                   3|
|   1|                   3|
|   3|                   4|
+----+--------------------+
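
A small side note, not from the answer itself: since each monotonically increasing id is distinct, a rowsBetween frame produces the same running sum here, and an ordered window with no explicit frame already defaults to unbounded preceding through the current row, which is what the first answer relies on:

# equivalent running-sum windows for this data (sketch)
w2_rows = W.orderBy(F.monotonically_increasing_id()).rowsBetween(W.unboundedPreceding, 0)
w2_default = W.orderBy(F.monotonically_increasing_id())  # default frame: unbounded preceding to current row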
anky