What I want to do:
In PySpark, I am trying to distribute N rows into X groups of the same size and to assign a specific value D to each group.
- Each row consists of attributes A, B, C (reference, item, location); all values of A are unique, but values of B and C are not.
- X is a constant declared upstream
- D is a date ranging from D1 = Today + 1 to DX = Today + X (a minimal sketch of this date mapping follows this list)
- Rows sharing the same {B;C} combination should not be split across groups (the same item at the same location must stay together and receive the same date)
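To make the date requirement concrete, here is a minimal sketch of how D could be derived from a 1-based group index, assuming a Spark session named spark (as used for df1 below); tmp and grp are purely illustrative names, not part of the real data:
import pyspark.sql.functions as F

# Illustration only: a hypothetical frame with a 1-based group index "grp"
tmp = spark.createDataFrame([('1234', 1), ('1235', 2)], ["A", "grp"])

# D = Today + grp, so grp = 1 gives D1 = Today + 1 and grp = X gives DX = Today + X
# (grp is cast to int because Python ints load as bigint, which date_add rejects)
tmp.withColumn("D", F.expr("date_add(current_date(), cast(grp AS int))")).show()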
What I have (df1):
df1 = spark.createDataFrame([
    ('1234', 'banana', 'Paris'),
    ('1235', 'orange', 'Berlin'),
    ('1236', 'orange', 'Paris'),
    ('1237', 'banana', 'Berlin'),
    ('1238', 'orange', 'Paris'),
    ('1239', 'banana', 'Berlin'),
], ["A", "B", "C"])
+----+------+------+
| A| B| C|
+----+------+------+
|1234|banana| Paris|
|1235|orange|Berlin|
|1236|orange| Paris|
|1237|banana|Berlin|
|1238|orange| Paris|
|1239|banana|Berlin|
+----+------+------+
What I want (df2):
e.g. when X = 3:
+----+------+------+-----+
| A| B| C| D|
+----+------+------+-----+
|1234|banana| Paris|date1|
|1235|orange|Berlin|date1|
|1236|orange| Paris|date2|
|1237|banana|Berlin|date3|
|1238|orange| Paris|date2|
|1239|banana|Berlin|date3|
+----+------+------+-----+
e.g. when X = 4:
+----+------+------+-----+
| A| B| C| D|
+----+------+------+-----+
|1234|banana| Paris|date1|
|1235|orange|Berlin|date4|
|1236|orange| Paris|date2|
|1237|banana|Berlin|date3|
|1238|orange| Paris|date2|
|1239|banana|Berlin|date3|
+----+------+------+-----+
e.g. when X = 5:
+----+------+------+-----+
| A| B| C| D|
+----+------+------+-----+
|1234|banana| Paris|date1|
|1235|orange|Berlin|date4|
|1236|orange| Paris|date2|
|1237|banana|Berlin|date3|
|1238|orange| Paris|date2|
|1239|banana|Berlin|date3|
+----+------+------+-----+
Note: the order in which the {B;C} combinations are ranked (and therefore which combination gets which date) can be arbitrary.
What I tried so far:
The following code distributes the rows evenly, but it does not respect the condition that identical {B;C} combinations must not be split:
>>> from pyspark.sql import Window
>>> import pyspark.sql.functions as F
>>> w = Window.orderBy('B', 'C')
>>> df2 = df1.withColumn("id", (F.row_number().over(w)) % 3)
>>> df2.show()
+----+------+------+---+
| A| B| C| id|
+----+------+------+---+
|1237|banana|Berlin| 1|
|1239|banana|Berlin| 2|
|1234|banana| Paris| 0|
|1235|orange|Berlin| 1|
|1236|orange| Paris| 2|
|1238|orange| Paris| 0|
+----+------+------+---+
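For illustration, here is a minimal sketch of one possible direction (not a full solution): rank the distinct {B;C} combinations with dense_rank so that rows sharing {B;C} get the same rank, fold that rank into X groups, and derive D from the group number. This keeps each {B;C} combination in a single group, but it only balances the number of combinations per group, not the number of rows, so it does not fully satisfy the equal-size requirement:
from pyspark.sql import Window
import pyspark.sql.functions as F

X = 3  # number of groups; declared upstream in the real job

# dense_rank assigns the same rank to every row sharing {B;C},
# so a whole combination always lands in the same group
w = Window.orderBy('B', 'C')
df2 = (df1
       .withColumn('grp', (F.dense_rank().over(w) - 1) % X + 1)   # group number 1..X
       .withColumn('D', F.expr('date_add(current_date(), grp)'))  # D1 = Today + 1 ... DX = Today + X
       .drop('grp'))
df2.show()
When X exceeds the number of distinct {B;C} combinations, some dates simply go unused, which matches the X = 5 example above.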