Two ways to do it.
Using your input data:
data_sdf.show()
# +---+------+
# |day|supply|
# +---+------+
# | 1| 3|
# | 3| 1|
# | 9| 5|
# | 10| 9|
# | 11| 1|
# +---+------+
Spark does not retain a sort order the way SAS data steps do, so we will have to sort the array or list wherever the order matters. The code below assumes pyspark.sql.functions is imported as func and an active SparkSession is available as spark.
Using arrays, aggregate() and a lambda function (Spark 3.1+)
Create day-supply structs and collect them into an array of those structs. array_sort() orders the array of structs by the day field (the first field within each struct). aggregate() takes an initial value and applies a function to each element of the provided array: here the initial value is a single-element array holding the first struct, and the lambda function is applied to each of the remaining structs. array_union() appends the struct produced by the lambda function to the accumulated array on each step. Finally, the inline() function creates separate columns from the resulting array of structs. A detailed explanation of its workings can be found in this answer.
data2_sdf = data_sdf. \
    withColumn('day_supply_struct', func.struct(func.col('day'), func.col('supply'))). \
    groupBy(func.lit(1).alias('group_field')). \
    agg(func.array_sort(func.collect_list('day_supply_struct')).alias('ds_struct_arr'))
# +-----------+------------------------------------------+
# |group_field|ds_struct_arr                             |
# +-----------+------------------------------------------+
# |1          |[{1, 3}, {3, 1}, {9, 5}, {10, 9}, {11, 1}]|
# +-----------+------------------------------------------+
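If aggregate() is unfamiliar, here is a minimal sketch of its initial-value-plus-lambda mechanics on a toy array column before it is applied to the structs (the toy dataframe and column names are made up purely for illustration):
from pyspark.sql import functions as func

# toy example: aggregate() as a running sum over an array column
toy_sdf = spark.createDataFrame([([1, 2, 3, 4],)], ['nums'])

toy_sdf. \
    withColumn('nums_sum',
               func.aggregate('nums',
                              func.lit(0).cast('long'),  # initial value; cast to match the array's element type
                              lambda acc, x: acc + x)). \
    show()
# +------------+--------+
# |        nums|nums_sum|
# +------------+--------+
# |[1, 2, 3, 4]|      10|
# +------------+--------+
The day-supply version below works the same way, except that the accumulator is an array of structs instead of a single number.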
# create new field within the struct
data3_sdf = data2_sdf. \
    withColumn('arr_struct_w_est_days',
               func.aggregate(func.slice(func.col('ds_struct_arr'), 2, data_sdf.count()),
                              func.array(func.col('ds_struct_arr')[0].withField('estimate_days', func.col('ds_struct_arr')[0]['day'])),
                              lambda x, y: func.array_union(x,
                                                            func.array(y.withField('estimate_days',
                                                                                   func.when(func.element_at(x, -1)['estimate_days'] + func.element_at(x, -1)['supply'] <= y['day'], y['day']).
                                                                                   otherwise(func.element_at(x, -1)['estimate_days'] + func.element_at(x, -1)['supply'])
                                                                                   )
                                                                       )
                                                            )
                              )
               )
# +-----------+------------------------------------------+-----------------------------------------------------------+
# |group_field|ds_struct_arr                             |arr_struct_w_est_days                                      |
# +-----------+------------------------------------------+-----------------------------------------------------------+
# |1          |[{1, 3}, {3, 1}, {9, 5}, {10, 9}, {11, 1}]|[{1, 3, 1}, {3, 1, 4}, {9, 5, 9}, {10, 9, 14}, {11, 1, 23}]|
# +-----------+------------------------------------------+-----------------------------------------------------------+
# create columns using the struct fields
data3_sdf. \
    selectExpr('inline(arr_struct_w_est_days)'). \
    show()
# +---+------+-------------+
# |day|supply|estimate_days|
# +---+------+-------------+
# | 1| 3| 1|
# | 3| 1| 4|
# | 9| 5| 9|
# | 10| 9| 14|
# | 11| 1| 23|
# +---+------+-------------+
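Side note: on Spark 3.4+ the same explode-to-columns step can also be written with the DataFrame API, since inline() is exposed there as func.inline(). This is just a stylistic variation; the selectExpr version above works on 3.1+:
# Spark 3.4+ only; equivalent to selectExpr('inline(arr_struct_w_est_days)')
data3_sdf. \
    select(func.inline('arr_struct_w_est_days')). \
    show()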
Using rdd, flatMapValues() and a python function
Create a python function that calculates the estimate days while keeping track of the previously calculated values. It takes in a group of row-lists (groupBy() is used to identify the grouping) and produces a list of row-lists.
# best to ship this function to all executors in case of huge datasets
def estimateDaysCalc(groupedRows):
    res = []
    frstRec = True

    for row in groupedRows:
        if frstRec:
            frstRec = False
            # the first day will have a static value
            estimate_days = row.day
        else:
            if prev_est_day + prev_supply <= row.day:
                estimate_days = row.day
            else:
                estimate_days = prev_est_day + prev_supply

        # keep track of the current calcs for next row calcs
        prev_est_day = estimate_days
        prev_supply = row.supply
        prev_day = row.day

        res.append([item for item in row] + [estimate_days])

    return res
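The logic can be sanity-checked locally on a few hand-built Row objects before wiring it into the RDD (this quick check is only an illustration, not part of the pipeline):
from pyspark.sql import Row

sample_rows = [Row(day=1, supply=3), Row(day=3, supply=1), Row(day=9, supply=5)]
print(estimateDaysCalc(sample_rows))
# [[1, 3, 1], [3, 1, 4], [9, 5, 9]]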
Run the function on the RDD using flatMapValues() and extract the values. A sort on the day field is required, so sorted() is used to order each group of row-lists by that field (the ok.day key).
data_vals = data_sdf.rdd. \
    groupBy(lambda gk: 1). \
    flatMapValues(lambda r: estimateDaysCalc(sorted(r, key=lambda ok: ok.day))). \
    values()
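To peek at the intermediate RDD before building a dataframe from it, a collect() (on small data only) should return the row values with the appended estimate:
data_vals.collect()
# [[1, 3, 1], [3, 1, 4], [9, 5, 9], [10, 9, 14], [11, 1, 23]]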
Create a schema for the new values.
# build the schema from a throwaway copy of the dataframe so that add()
# does not modify data_sdf's own cached schema object
data_schema = data_sdf. \
    withColumn('dropme', func.lit(None).cast('int')). \
    drop('dropme'). \
    schema. \
    add('estimate_days', 'integer')
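If the withColumn/drop copy trick reads as too clever, the schema can also be spelled out explicitly; a sketch assuming day and supply are plain integers, matching the sample data:
from pyspark.sql.types import StructType, StructField, IntegerType

# explicit equivalent of the derived schema above (column types assumed)
data_schema = StructType([
    StructField('day', IntegerType()),
    StructField('supply', IntegerType()),
    StructField('estimate_days', IntegerType())
])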
Create a dataframe using the newly created values and schema.
new_data_sdf = spark.createDataFrame(data_vals, data_schema)
new_data_sdf.show()
# +---+------+-------------+
# |day|supply|estimate_days|
# +---+------+-------------+
# | 1| 3| 1|
# | 3| 1| 4|
# | 9| 5| 9|
# | 10| 9| 14|
# | 11| 1| 23|
# +---+------+-------------+