
Dear PySpark community:

I would like to calculate the estimate_day_to_sustain before each day's supply arrives. The original code is written in SAS using a 'retain' statement, but I cannot find a way to do the same in PySpark. Please help, thanks!

input data:

day  supply
  1       3
  3       1
  9       5
 10       9
 11       1

output data:

day  supply  estimate_day_to_sustain
  1       3                        1
  3       1                        4
  9       5                        9
 10       9                       14
 11       1                       23

Algorithm:

  1. On day 1: current estimate_day_to_sustain = current day

  2. On other days:

    1> if previous estimate_day_to_sustain + previous supply <= current day, then current estimate_day_to_sustain = current day

    2> else, current estimate_day_to_sustain = previous estimate_day_to_sustain + previous supply

Explanation of the algorithm:

  1. On day 1: the estimate_day_to_sustain is 1; by the end of the day, 3 days of supply arrive.
  2. On day 3, we have 1+3=4 days of supply (from the previous row), and it's only day 3, so the estimate_day_to_sustain is 4; by the end of the day, 1 day of supply arrives.
  3. On day 9, we have 4+1=5 days of supply (from the previous row), but it's already day 9, so the estimate_day_to_sustain is 9 (this is the tricky part); by the end of the day, 5 days of supply arrive.
  4. On day 10, we have 9+5=14 days of supply (from the previous row), and it's only day 10, so the estimate_day_to_sustain is 14; by the end of the day, 9 days of supply arrive.
  5. On day 11, we have 14+9=23 days of supply (from the previous row), and it's only day 11, so the estimate_day_to_sustain is 23; by the end of the day, 1 day of supply arrives.
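
A plain-Python sketch of this recurrence on the sample rows, just to make the logic concrete (this is only an illustration of the algorithm, not the PySpark solution I am looking for):

# plain-Python version of the algorithm above, applied to the sample data
rows = [(1, 3), (3, 1), (9, 5), (10, 9), (11, 1)]   # (day, supply)

prev_est, prev_supply = None, None
for day, supply in rows:
    if prev_est is None:
        est = day                                    # day 1: estimate = current day
    elif prev_est + prev_supply <= day:
        est = day                                    # previous supply is already used up
    else:
        est = prev_est + prev_supply                 # previous supply lasts past today
    print(day, supply, est)                          # est: 1, 4, 9, 14, 23
    prev_est, prev_supply = est, supply
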
Bobby

1 Answer


Two ways to do it -

using your input data

data_sdf.show()

# +---+------+
# |day|supply|
# +---+------+
# |  1|     3|
# |  3|     1|
# |  9|     5|
# | 10|     9|
# | 11|     1|
# +---+------+
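
for reference, the sample frame above can be recreated like this (a minimal sketch; integer column types and an active spark session are assumed)

from pyspark.sql import functions as func   # alias used throughout this answer

# recreate the sample input shown above
data_sdf = spark.createDataFrame(
    [(1, 3), (3, 1), (9, 5), (10, 9), (11, 1)],
    ['day', 'supply']
)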

Spark does not retain a sort order like SAS data steps do. So, we will have to sort the array or list wherever required.


Using arrays, aggregate() and a lambda function

SPARK 3.1+

create day-supply structs and collect them to form an array of the said structs. array_sort() is used to order the array of structs by the day field (the first element within each struct). aggregate() takes an initial value and applies a lambda function to each element of the provided array. Here, the initial value is an array holding the first struct, and the lambda function is applied to each of the remaining structs. array_union() appends the newly built single-struct array, produced by the lambda function, to the accumulated array on each step. Finally, the inline() function creates separate columns from the resulting array of structs. A detailed explanation of its workings can be found in this answer.
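
to see how aggregate() accumulates before looking at the full code, here is a tiny standalone example (a sketch, assuming Spark 3.1+ and an active spark session)

from pyspark.sql import functions as func

# toy example: running sum over an array column, just to show aggregate()'s accumulator mechanics
demo_sdf = spark.createDataFrame([([1, 2, 3, 4],)], ['arr'])

demo_sdf. \
    select(func.aggregate('arr', func.lit(0).cast('long'), lambda acc, x: acc + x).alias('arr_sum')). \
    show()

# +-------+
# |arr_sum|
# +-------+
# |     10|
# +-------+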

data2_sdf = data_sdf. \
    withColumn('day_supply_struct', func.struct(func.col('day'), func.col('supply'))). \
    groupBy(func.lit(1).alias('group_field')). \
    agg(func.array_sort(func.collect_list('day_supply_struct')).alias('ds_struct_arr'))

# +-----------+------------------------------------------+
# |group_field|ds_struct_arr                             |
# +-----------+------------------------------------------+
# |1          |[{1, 3}, {3, 1}, {9, 5}, {10, 9}, {11, 1}]|
# +-----------+------------------------------------------+

# create new field within the struct
data3_sdf = data2_sdf. \
    withColumn('arr_struct_w_est_days', 
               func.aggregate(func.slice(func.col('ds_struct_arr'), 2, data_sdf.count()), 
                              func.array(func.col('ds_struct_arr')[0].withField('estimate_days', func.col('ds_struct_arr')[0]['day'])), 
                              lambda x, y: func.array_union(x, 
                                                            func.array(y.withField('estimate_days', 
                                                                                   func.when(func.element_at(x, -1)['estimate_days'] + func.element_at(x, -1)['supply'] <= y['day'], y['day']).
                                                                                   otherwise(func.element_at(x, -1)['estimate_days'] + func.element_at(x, -1)['supply'])
                                                                                   )
                                                                       )
                                                            )
                              )
    )

# +-----------+------------------------------------------+-----------------------------------------------------------+
# |group_field|ds_struct_arr                             |arr_struct_w_est_days                                      |
# +-----------+------------------------------------------+-----------------------------------------------------------+
# |1          |[{1, 3}, {3, 1}, {9, 5}, {10, 9}, {11, 1}]|[{1, 3, 1}, {3, 1, 4}, {9, 5, 9}, {10, 9, 14}, {11, 1, 23}]|
# +-----------+------------------------------------------+-----------------------------------------------------------+

# create columns using the struct fields
data3_sdf. \
    selectExpr('inline(arr_struct_w_est_days)'). \
    show()

# +---+------+-------------+
# |day|supply|estimate_days|
# +---+------+-------------+
# |  1|     3|            1|
# |  3|     1|            4|
# |  9|     5|            9|
# | 10|     9|           14|
# | 11|     1|           23|
# +---+------+-------------+

using rdd, flatMapValues() and a python function

create a python function to calculate the estimate days while keeping track of the previously calculated values. It takes in a group of row-lists (groupBy() is used to identify the grouping) and returns a list of row-lists.

# best to ship this function to all executors in case of huge datasets
def estimateDaysCalc(groupedRows):

    res = []
    frstRec = True

    for row in groupedRows:
        if frstRec:
            frstRec = False
            # the first day will have a static value
            estimate_days = row.day
        else:
            if prev_est_day + prev_supply <= row.day:
                estimate_days = row.day
            else:
                estimate_days = prev_est_day + prev_supply
        
        # keep track of the current calcs for next row calcs
        prev_est_day = estimate_days
        prev_supply = row.supply
        prev_day = row.day

        res.append([item for item in row] + [estimate_days])

    return res

run the function on the RDD using flatMapValues() and extract the values. A sort on the day field is required, so sorted() is used to order each group of row-lists by that field (ok.day).

data_vals = data_sdf.rdd. \
    groupBy(lambda gk: 1). \
    flatMapValues(lambda r: estimateDaysCalc(sorted(r, key=lambda ok:ok.day))). \
    values()

create schema for the new values

# the no-op withColumn()/drop() yields a copy of the dataframe, so add() below
# extends a fresh schema object instead of data_sdf's own schema
data_schema = data_sdf. \
    withColumn('dropme', func.lit(None).cast('int')). \
    drop('dropme'). \
    schema. \
    add('estimate_days', 'integer')

create dataframe using the newly created values and schema

new_data_sdf = spark.createDataFrame(data_vals, data_schema)

new_data_sdf.show()

# +---+------+-------------+
# |day|supply|estimate_days|
# +---+------+-------------+
# |  1|     3|            1|
# |  3|     1|            4|
# |  9|     5|            9|
# | 10|     9|           14|
# | 11|     1|           23|
# +---+------+-------------+
samkart
  • Hi Samkart: Thank you so much for the answer! I tried your first approach, however, at "# create new field within the struct" "func.aggregate" , it produced error : "Module pyspark.sql.functions does not contain an attribute aggregate. Please review your code.". Then I replaced with "data2_sdf.agg", still error: "There is an incorrect call to a Column object in your code. Please review your code." Can you check this for me? Thanks! My Pyspark version is 3.0.2 – Bobby Jul 19 '22 at 20:00
  • @Bobby as mentioned in the answer, that approach is only available starting spark 3.1 – samkart Jul 20 '22 at 03:35