Consider the following DataFrame:
+------+-----------------------+
|type |names |
+------+-----------------------+
|person|[john, sam, jane] |
|pet |[whiskers, rover, fido]|
+------+-----------------------+
It can be created with the following code:

import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [
    ('person', ['john', 'sam', 'jane']),
    ('pet', ['whiskers', 'rover', 'fido'])
]
df = spark.createDataFrame(data, ["type", "names"])
df.show(truncate=False)
Is there a way to directly modify the ArrayType() column "names" by applying a function to each element, without using a udf?

For example, suppose I wanted to apply the function foo to the "names" column. (I will use the example where foo is str.upper just for illustrative purposes, but my question is about any valid function that can be applied to the elements of an iterable.)
foo = lambda x: x.upper()  # defining it as str.upper as an example

df.withColumn('X', [foo(x) for x in f.col("names")]).show()
# TypeError: Column is not iterable
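As far as I can tell, this fails because f.col("names") returns a Column expression rather than a Python iterable on the driver:

print(type(f.col("names")))
# <class 'pyspark.sql.column.Column'>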
I could do this using a udf:

from pyspark.sql.types import ArrayType, StringType

foo_udf = f.udf(lambda row: [foo(x) for x in row], ArrayType(StringType()))
df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)
#+------+-----------------------+
#|type |names |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE] |
#|pet |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+
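For completeness, the same udf can also be written with the decorator syntax (same behavior, just an alternative spelling, using the type imports above):

@f.udf(returnType=ArrayType(StringType()))
def foo_udf(names):
    return [foo(x) for x in names]

df.withColumn('names', foo_udf(f.col('names'))).show(truncate=False)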
In this specific example, I could avoid the udf by exploding the column, calling pyspark.sql.functions.upper(), and then using groupBy and collect_list:
df.select('type', f.explode('names').alias('name'))\
    .withColumn('name', f.upper(f.col('name')))\
    .groupBy('type')\
    .agg(f.collect_list('name').alias('names'))\
    .show(truncate=False)
#+------+-----------------------+
#|type |names |
#+------+-----------------------+
#|person|[JOHN, SAM, JANE] |
#|pet |[WHISKERS, ROVER, FIDO]|
#+------+-----------------------+
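One caveat I should mention: as far as I know, collect_list does not guarantee that the original element order survives the shuffle, and explode drops rows whose array is null or empty. A sketch that at least preserves order by carrying the element position through with posexplode (Spark 2.1+; sort_array orders the structs by pos since it is the struct's first field):

df.select('type', f.posexplode('names').alias('pos', 'name'))\
    .withColumn('name', f.upper(f.col('name')))\
    .groupBy('type')\
    .agg(f.sort_array(f.collect_list(f.struct('pos', 'name'))).alias('tmp'))\
    .select('type', f.col('tmp.name').alias('names'))\
    .show(truncate=False)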
But this is a lot of code to do something simple. Is there a more direct way to iterate over the elements of an ArrayType() using spark-dataframe functions?
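For reference, I have seen hints that newer Spark releases add higher-order functions for exactly this, though I have not been able to verify it on my cluster, so treat the following as a sketch: on Spark 2.4+ the SQL transform function should be reachable through expr, and Spark 3.1+ appears to expose it directly in Python:

# Spark 2.4+: SQL higher-order function via expr
df.withColumn('names', f.expr("transform(names, x -> upper(x))")).show(truncate=False)

# Spark 3.1+: native Python API
df.withColumn('names', f.transform('names', lambda x: f.upper(x))).show(truncate=False)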