A when().otherwise() can be used here, chained with reduce() (from functools).
from functools import reduce
import pyspark.sql.functions as func

data_sdf = spark.range(3).toDF('num')
# +---+
# |num|
# +---+
# | 0|
# | 1|
# | 2|
# +---+
replace_dict = {
1: 'A',
2: 'B',
3: 'C'
}
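For reference, the chain we want to end up with is the same as writing the when() calls out by hand (shown here just for comparison):

# hand-written equivalent of the chain reduce() builds below
manual_statement = func.when(func.col('num') == 1, 'A'). \
    when(func.col('num') == 2, 'B'). \
    when(func.col('num') == 3, 'C'). \
    otherwise(func.lit(None))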
We can use reduce() to chain the when() statements.
when_statement = reduce(
    lambda x, y: x.when(func.col('num') == y, replace_dict[y]),
    replace_dict.keys(),
    # seed for the chain; the condition "num = NULL" is never true, so this branch never fires
    func.when(func.col('num') == None, func.lit(None))
). \
    otherwise(func.lit(None))
print(when_statement)
# Column<'CASE WHEN (num = NULL) THEN NULL WHEN (num = 1) THEN A WHEN (num = 2) THEN B WHEN (num = 3) THEN C ELSE NULL END'>
data_sdf. \
withColumn('replaced_vals', when_statement). \
show()
# +---+-------------+
# |num|replaced_vals|
# +---+-------------+
# | 0| null|
# | 1| A|
# | 2| B|
# +---+-------------+
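If you need this for more than one column or dictionary, the same pattern can be wrapped in a small helper. This is just a sketch; dict_to_when is a hypothetical name, and the seed uses func.lit(False) so the extra branch can never fire:

def dict_to_when(colname, mapping):
    # fold the dictionary into CASE WHEN colname = key THEN value ... ELSE NULL END
    return reduce(
        lambda chain, k: chain.when(func.col(colname) == k, mapping[k]),
        mapping.keys(),
        func.when(func.lit(False), func.lit(None))  # no-op seed to start the chain
    ). \
        otherwise(func.lit(None))

data_sdf.withColumn('replaced_vals', dict_to_when('num', replace_dict)).show()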
reduce() applies a function cumulatively to the items of an iterable, folding them from left to right. Its signature is reduce(function, iterable[, initializer]). The function is our lambda, which adds one when() branch per dictionary key; the iterable is the dictionary's keys, each of which is used to look up its replacement value; and the last part is optional but important in this case - it is the initial value that sits at the top of the chain. Because we wanted a func.when().when()....otherwise() chain, we passed a first func.when() as the initializer, and the remaining branches were chained onto it one by one. Note that num = 0 lands in the otherwise() branch (null) because 0 is not a key in the dictionary, while the branch for key 3 never matches because the data only contains 0 through 2.
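To make the fold concrete, here is a plain-Python illustration of reduce() with an initializer (the numbers are just for demonstration):

from functools import reduce

# reduce(f, [a, b, c], init) evaluates as f(f(f(init, a), b), c)
total = reduce(lambda acc, x: acc + x, [1, 2, 3], 10)
print(total)
# 16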