
What I want to do:

In PySpark, I am trying to distribute N rows into X groups of the same size and to assign a specific value D to each group.

  • Each row consists of attributes A, B, C (reference, item, location), where all values of A are unique, but values of B and C are not.
  • X is a constant declared upstream
  • D is a date from D1 = Today + 1 to Dx = Today + X
  • Rows with the same combination {B;C} must not be split across different groups (the same item for the same location should stay together and receive the same date)

What I have (df1):

df1 = spark.createDataFrame([ ('1234','banana','Paris'),
                            ('1235','orange','Berlin'),
                            ('1236','orange','Paris'),
                            ('1237','banana','Berlin'),
                            ('1238','orange','Paris'),
                            ('1239','banana','Berlin'),
                       ], ["A","B","C"])

+----+------+------+
|   A|     B|     C|
+----+------+------+
|1234|banana| Paris|
|1235|orange|Berlin|
|1236|orange| Paris|
|1237|banana|Berlin|
|1238|orange| Paris|
|1239|banana|Berlin|
+----+------+------+

What I want (df2):

e.g. when X = 3:

    +----+------+------+-----+
    |   A|     B|     C|    D|
    +----+------+------+-----+
    |1234|banana| Paris|date1|
    |1235|orange|Berlin|date1|
    |1236|orange| Paris|date2|
    |1237|banana|Berlin|date3|
    |1238|orange| Paris|date2|
    |1239|banana|Berlin|date3|
    +----+------+------+-----+

e.g. when X = 4:

    +----+------+------+-----+
    |   A|     B|     C|    D|
    +----+------+------+-----+
    |1234|banana| Paris|date1|
    |1235|orange|Berlin|date4|
    |1236|orange| Paris|date2|
    |1237|banana|Berlin|date3|
    |1238|orange| Paris|date2|
    |1239|banana|Berlin|date3|
    +----+------+------+-----+


e.g. when X = 5:

    +----+------+------+-----+
    |   A|     B|     C|    D|
    +----+------+------+-----+
    |1234|banana| Paris|date1|
    |1235|orange|Berlin|date4|
    |1236|orange| Paris|date2|
    |1237|banana|Berlin|date3|
    |1238|orange| Paris|date2|
    |1239|banana|Berlin|date3|
    +----+------+------+-----+


Note: the order in which the {B;C} combinations are ranked can be arbitrary.


What I tried so far:

The following code distributes the rows equally, but it cannot respect the condition that identical {B;C} combinations must not be split:

>>> from pyspark.sql import Window, functions as F
>>> w = Window.orderBy('B', 'C')
>>> df2 = df1.withColumn("id", F.row_number().over(w) % 3)
>>> df2.show()
+----+------+------+---+
|   A|     B|     C| id|
+----+------+------+---+
|1237|banana|Berlin|  1|
|1239|banana|Berlin|  2|
|1234|banana| Paris|  0|
|1235|orange|Berlin|  1|
|1236|orange| Paris|  2|
|1238|orange| Paris|  0|
+----+------+------+---+



2 Answers


Use dense_rank instead of row_number. If you mod by 3, you're not guaranteed to get equal-sized groups, but it will be close, depending on how your data is shuffled. If it needs to be as exact as possible, you can split it with something like floor(dense_rank_col / max(dense_rank_col) * 3).

  • Indeed, with dense_rank() the groups are not all equally sized, but the variation seems acceptable: out of 4,310 rows distributed into 9 groups, the sizes vary as follows: {464, 475, 482, 496, 481, 476, 476, 480, 480}. However, while attempting to use the combination with floor(), I got the following: TypeError: Column is not iterable | `df2 = df1.withColumn("id",floor(dense_rank().over(w) / max(dense_rank().over(w))) * X)` - what did I get wrong? – ionah Sep 21 '21 at 10:43
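
For reference, a minimal sketch of the dense_rank plus floor bucketing suggested in this answer (the X constant and the w_all window are assumptions here, not part of the original answer). The TypeError in the comment most likely comes from calling Python's built-in max on a Column instead of pyspark.sql.functions.max:

    from pyspark.sql import Window as W, functions as F

    X = 3  # number of groups, declared upstream

    # Rank the distinct {B;C} combinations; identical combinations share one rank.
    w = W.orderBy("B", "C")
    df2 = df1.withColumn("rank", F.dense_rank().over(w))

    # Bucket the ranks into X groups of roughly equal size.
    # Use F.max (not Python's built-in max) over a window spanning all rows;
    # shifting the rank by 1 keeps the bucket ids in the range 0 .. X-1.
    w_all = w.rowsBetween(W.unboundedPreceding, W.unboundedFollowing)
    df2 = df2.withColumn(
        "id", F.floor((F.col("rank") - 1) / F.max("rank").over(w_all) * X)
    )
    df2.show()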

An alternative answer was proposed to me, as follows:


Make use of collect_list and explode:

df1 = spark.createDataFrame([ ('1234','banana','Paris'),
                            ('1235','orange','Berlin'),
                            ('1236','orange','Paris'),
                            ('1237','banana','Berlin'),
                            ('1238','orange','Paris'),
                            ('1239','banana','Berlin'),
                       ], ["A","B","C"])

from pyspark.sql import Window as W, functions as F

df = df1.groupBy("B", "C").agg(F.collect_list("A").alias("A"))\
        .withColumn("id", F.rand())\
        .withColumn("id", F.row_number().over(W.partitionBy().orderBy("id")) % 3)\
        .withColumn("A", F.explode("A"))
df.show()

+------+------+----+---+
|     B|     C|   A| id|
+------+------+----+---+
|banana|Berlin|1237|  1|
|banana|Berlin|1239|  1|
|orange|Berlin|1235|  2|
|orange| Paris|1236|  0|
|orange| Paris|1238|  0|
|banana| Paris|1234|  1|
+------+------+----+---+

The result is essentially the same as in the answer provided by PySpark Helper.
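
For completeness, a hedged sketch of turning the group id produced by either answer into the date D asked for in the question (assuming an id column with values 0 .. X-1; date_add is called through F.expr so the day offset can be a column expression):

    from pyspark.sql import functions as F

    # Map group id 0 .. X-1 to D1 = today + 1 ... Dx = today + X.
    df2 = df.withColumn("D", F.expr("date_add(current_date(), id + 1)"))
    df2.show()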
