
Background

I use explode to transpose columns to rows. In general this works very well, with good performance. The source dataframe (df_audit in the code below) is dynamic, so its structure can vary.

Problem

Recently an incoming dataframe arrived with a very large number of columns (5 thousand). The code below runs successfully, but the line starting with 'exploded' is very slow. Has anyone faced similar problems? I could split the dataframe up into multiple dataframes (broken out by columns), or might there be a better way? Or example code?

Example code

from pyspark.sql.functions import array, col, explode, lit, struct

key_cols = ["cola", "colb", "colc"]

cols = [c for c in df_audit.columns if c not in key_cols]

# build one struct per non-key column, collect them into an array, then explode
exploded = explode(array([struct(lit(c).alias("key"), col(c).alias("val")) for c in cols])).alias("exploded")

df_audit = df_audit.select(key_cols + [exploded]).select(key_cols + ["exploded.key", "exploded.val"])
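
For reference, a minimal sketch of the splitting idea mentioned above: explode the non-key columns in batches and union the long-format pieces back together. The names batch_size and df_long are illustrative, not from the original code, and the batch size would need tuning:

from functools import reduce

# hypothetical sketch: explode a few hundred columns at a time,
# then union the partial long-format results
batch_size = 500  # illustrative; tune to the workload
parts = []
for i in range(0, len(cols), batch_size):
    batch = cols[i:i + batch_size]
    exploded_batch = explode(
        array([struct(lit(c).alias("key"), col(c).alias("val")) for c in batch])
    ).alias("exploded")
    parts.append(
        df_audit.select(key_cols + [exploded_batch])
                .select(key_cols + ["exploded.key", "exploded.val"])
    )
df_long = reduce(lambda a, b: a.unionByName(b), parts)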
B Mart
  • I've seen that the `stack` SQL function gives good performance as well, but I usually use your stated method (however, instead of explode I use the `inline` SQL function, which explodes as well as creating n columns from the structs). I'm guessing the slowness is due to the large number of columns, as each row becomes 5k rows. – samkart Aug 23 '22 at 15:53
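
For reference, a minimal sketch of the stack() alternative mentioned in the comment above, assuming the df_audit, key_cols and cols from the question. The whole key/value mapping is built into a single SQL generator expression, so no per-column lit()/col() objects are created:

# one stack() call: produces len(cols) rows per input row, one per non-key column
stack_expr = "stack({n}, {pairs}) as (key, val)".format(
    n=len(cols),
    pairs=", ".join("'{0}', `{0}`".format(c) for c in cols),
)
df_long = df_audit.selectExpr(*key_cols, stack_expr)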

3 Answers


Both lit() and col() are for some reason quite slow when used in a loop. You can try arrays_zip() instead:

from pyspark.sql.functions import array, arrays_zip, explode, lit, split

# a single split() over one joined literal replaces thousands of individual lit() calls
exploded = explode(
    arrays_zip(split(lit(','.join(cols)), ',').alias('key'), array(cols).alias('val'))
).alias('exploded')

In my quick test on 5k columns, this runs in ~6s vs. ~25s for the original.

bzu

Sharing some timings for bzu's approach and the OP's approach, based on a Colaboratory notebook.

import pyspark.sql.functions as func

cols = ['i'+str(i) for i in range(5000)]

# OP's method
%timeit func.array(*[func.struct(func.lit(k).alias('k'), func.col(k).alias('v')) for k in cols])
# 34.7 s ± 2.84 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

# bzu's method
%timeit func.arrays_zip(func.split(func.lit(','.join(cols)), ',').alias('k'), func.array(cols).alias('v'))
# 10.7 s ± 1.41 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
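
Note that these %timeit calls only build the column expression on the driver; no Spark job is executed. That the gap shows up even there is consistent with bzu's observation that the per-column lit()/col() calls themselves are the slow part.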
samkart
  • Thank you for your comments bzu & samkart. I have these functions imported, but I get the error 'Column' object is not callable when I run test = arrays_zip(split(lit(','.join(cols)), ',').alias('key'), array(cols).alias('val')) ... any thoughts? I am using Spark 3.1 – B Mart Aug 24 '22 at 18:26
  • @BMart it works fine for me -- spark 3.1.3 – samkart Aug 25 '22 at 11:16
  • I have provided example code below (in a separate answer) which does not work. Could you let me know where I'm going wrong? Appreciate the help. – B Mart Aug 26 '22 at 14:35
  • The code you've sent also works on PySpark 3.3.0 when run directly, and with spark-submit on 3.1.3. Did you try to run it in some other environment? Also, can you paste the full example code, including any initialization? – bzu Aug 26 '22 at 15:57

Thank you bzu & samkart, but for some reason I cannot get the new line working. I have created a simple example below that does not work; perhaps you can see something obvious I am missing.

import logging

from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    array, arrays_zip, coalesce, col, explode, lit, lower, split, struct, substring,
)
from pyspark.sql.types import StringType

# logger and spark were presumably set up elsewhere in the original;
# defined here so the example is self-contained
logger = logging.getLogger(__name__)
spark = SparkSession.builder.getOrCreate()

def process_data():
    try:
        logger.info("\ntest 1")
        df_audit = spark.createDataFrame(
            [("1", "foo", "abc", "xyz"), ("2", "bar", "def", "zab")],
            ["id", "label", "colx", "coly"],
        )

        logger.info("\ntest 2")
        key_cols = ["id", "label"]
        cols = [c for c in df_audit.columns if c not in key_cols]

        logger.info("\ntest 3")
        # original (working) version:
        # exploded = explode(array([struct(lit(c).alias("key"), col(c).alias("val")) for c in cols])).alias("exploded")
        exploded = explode(arrays_zip(split(lit(','.join(cols)), ',').alias('key'), array(cols).alias('val'))).alias('exploded')

        logger.info("\ntest 4")
        df_audit = df_audit.select(key_cols + [exploded]).select(key_cols + ["exploded.key", "exploded.val"])
        df_audit.show()
    except Exception as e:
        logger.error("Error in process_audit_data: {}".format(e))
        return False
    return True

When I call the process_data function I get the following logged: test 1, test 2, test 3, test 4, and then: Error in process_audit_data: No such struct field key in 0, 1. Note: it does work successfully with the commented-out exploded line.
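
For what it's worth, the error message suggests that on this Spark version the struct fields produced by arrays_zip end up named 0 and 1 rather than picking up the key/val aliases. A sketch of a possible workaround (an assumption based on that error message, not a confirmed fix) is to select the fields by those positional names and rename them:

# sketch of a possible workaround, assuming the zipped struct fields are
# really named "0" and "1" here, as the error message indicates
df_audit = (
    df_audit
    .select(key_cols + [exploded])
    .select(
        key_cols
        + [
            col("exploded").getField("0").alias("key"),
            col("exploded").getField("1").alias("val"),
        ]
    )
)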

Many thanks

B Mart