
I'm working on a PySpark routine to interpolate the missing values in a configuration table.

Imagine a table of configuration values that go from 0 to 50,000. The user specifies a few data points in between (say at 0, 50, 100, 500, 2000, 50000) and we interpolate the remainder. My solution follows this blog post quite closely, except that I'm not using any UDFs.
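For concreteness, the two inputs look roughly like this (the column names match the code below, but the PORT_TYPE / loss_process values and the way I build the template here are simplified stand-ins, not my actual tables):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Illustrative user-supplied config: a handful of known (rank, scale_factor) points
# per (PORT_TYPE, loss_process) combination
df_users_config = spark.createDataFrame(
    [('fiber', 'attenuation', 0,     1.00),
     ('fiber', 'attenuation', 50,    1.20),
     ('fiber', 'attenuation', 100,   1.50),
     ('fiber', 'attenuation', 500,   2.00),
     ('fiber', 'attenuation', 2000,  3.00),
     ('fiber', 'attenuation', 50000, 10.00)],
    ['PORT_TYPE', 'loss_process', 'rank', 'scale_factor'])

# Illustrative template: one row for every rank from 0 to 50,000 per combination,
# so the left join below leaves scale_factor null wherever the user didn't supply a point
df_scale_factors_template = (
    df_users_config
    .select('PORT_TYPE', 'loss_process')
    .distinct()
    .crossJoin(spark.range(0, 50001).withColumnRenamed('id', 'rank'))
)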

In troubleshooting the performance of this routine (it takes ~3 minutes), I found that one particular window function is taking all of the time, and everything else I'm doing takes mere seconds.

Here is the main area of interest - where I use window functions to fill in the previous and next user-supplied configuration values:

from pyspark.sql import Window, functions as F

# Create partition windows that are required to generate new rows from the ones provided
win_last = Window.partitionBy('PORT_TYPE', 'loss_process').orderBy('rank').rowsBetween(Window.unboundedPreceding, 0)
win_next = Window.partitionBy('PORT_TYPE', 'loss_process').orderBy('rank').rowsBetween(0, Window.unboundedFollowing)

# Join back in the provided config table to populate the "known" scale factors
df_part1 = (df_scale_factors_template
  .join(df_users_config, ['PORT_TYPE', 'loss_process', 'rank'], 'leftouter')
  # Add computed columns that can lookup the prior config and next config for each missing value
  .withColumn('last_rank', F.last( F.col('rank'),         ignorenulls=True).over(win_last))
  .withColumn('last_sf',   F.last( F.col('scale_factor'), ignorenulls=True).over(win_last))
).cache()
debug_log_dataframe(df_part1, 'df_part1') # Force a .count() and time Part 1

df_part2 = (df_part1
  .withColumn('next_rank', F.first(F.col('rank'),         ignorenulls=True).over(win_next))
  .withColumn('next_sf',   F.first(F.col('scale_factor'), ignorenulls=True).over(win_next))
).cache()
debug_log_dataframe(df_part2, 'df_part2') # Force a .count() and time Part 2

df_part3 = (df_part2
  # Implements standard linear interpolation: y = y1 + ((y2-y1)/(x2-x1)) * (x-x1)
  .withColumn('scale_factor', 
              F.when(F.col('last_rank')==F.col('next_rank'), F.col('last_sf')) # Handle div/0 case
              .otherwise(F.col('last_sf') + ((F.col('next_sf')-F.col('last_sf'))/(F.col('next_rank')-F.col('last_rank'))) * (F.col('rank')-F.col('last_rank'))))
  .select('PORT_TYPE', 'loss_process', 'rank', 'scale_factor')
).cache()
debug_log_dataframe(df_part3, 'df_part3', explain=True) # Force a .count() and time Part 3

The above used to be a single chained dataframe statement, but I've since split it into 3 parts so that I could isolate the part that's taking so long. The results are:

  • Part 1: Generated 8 columns and 300006 rows in 0.65 seconds
  • Part 2: Generated 10 columns and 300006 rows in 189.55 seconds
  • Part 3: Generated 4 columns and 300006 rows in 0.24 seconds

Why do my calls to first() over Window.unboundedFollowing take so much longer than last() over Window.unboundedPreceding?


Some notes to head off questions / concerns:

  • debug_log_dataframe is just a helper function to force the dataframe execution/cache with a .count() and time it to yield the above logs (a rough sketch of it follows this list).
  • We're actually operating on 6 config tables of 50001 rows at once (hence the partitioning and row count)
  • As a sanity check, I've ruled out the effects of cache() reuse by explicitly unpersist()ing before timing subsequent runs - I'm quite confident in the above measurements.
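For reference, debug_log_dataframe is roughly equivalent to this sketch (simplified, but it produces the timing lines quoted above in the same format):

import time

def debug_log_dataframe(df, name, explain=False):
    # Force execution / cache materialisation with a count, and time it
    start = time.time()
    row_count = df.count()
    elapsed = time.time() - start
    print(f"{name}: Generated {len(df.columns)} columns and {row_count} rows in {elapsed:.2f} seconds")
    if explain:
        df.explain()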

Physical Plan: To help answer this question, I call explain() on the result of part 3 to confirm, among other things, that caching is having the desired effect. Here it is, annotated to highlight the problem area: (screenshot of the annotated physical plan)
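(For anyone reproducing this: the plan comes from a plain explain() call on the final cached dataframe; on Spark 3.x the formatted mode is optional but labels each operator, which makes the Window vs RunningWindowFunction nodes easier to spot.)

# Print the physical plan for the final dataframe
df_part3.explain()
# or, on Spark 3.x:
df_part3.explain(mode='formatted')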

The only differences I can see is that:

  • The first two calls (to last) show RunningWindowFunction, whereas the calls to first just read Window
  • Part 1 had a *(3) next to it (i.e. whole-stage codegen), but Part 2 does not.

Some things I tried:

  • I tried further splitting part 2 into separate dataframes - the result is that each first() call takes about half of the total time (~98 seconds)
  • I tried reversing the order in which I generate these columns (e.g. placing the calls to 'last' after the calls to 'first') but there's no difference. Whichever dataframe ends up containing the calls to first is the slow one.

I feel like I've done as much digging as I can and am kind of hoping a Spark expert will take one look and know where this time is coming from.

Alain
  • I've heard more than one DBA make a blanket statement along the lines of "Avoid OLAP functions, they are more trouble than they're worth" and while I don't always follow that advice, seeing stuff like this definitely helps me understand where they're coming from. – Z4-tier Sep 24 '21 at 07:56
  • Spark 3.1.1 by the way, for anyone wondering if this is a version issue. – Alain Sep 24 '21 at 13:42
  • I created https://issues.apache.org/jira/browse/SPARK-36844 on the off-chance that this is unintended behaviour. – Alain Sep 24 '21 at 15:15

1 Answer


The solution that doesn't answer the question

In trying various things to speed up my routine, it occurred to me to try rewriting my usages of first() as usages of last() with a reversed sort order.

So rewriting this:

win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
  .orderBy('rank').rowsBetween(0, Window.unboundedFollowing))

df_part2 = (df_part1
  .withColumn('next_rank', F.first(F.col('rank'),         ignorenulls=True).over(win_next))
  .withColumn('next_sf',   F.first(F.col('scale_factor'), ignorenulls=True).over(win_next))
)

As this:

win_next = (Window.partitionBy('PORT_TYPE', 'loss_process')
  .orderBy(F.desc('rank')).rowsBetween(Window.unboundedPreceding, 0))

df_part2 = (df_part1
  .withColumn('next_rank', F.last(F.col('rank'),         ignorenulls=True).over(win_next))
  .withColumn('next_sf',   F.last(F.col('scale_factor'), ignorenulls=True).over(win_next))
)

Much to my amazement, this actually solved the performance problem, and now the entire dataframe is generated in just 3 seconds. I'm pleased, but still vexed.

As I somewhat predicted, the query plan now includes a new SORT step before creating these next two columns, and they've changed from Window to RunningWindowFunction, like the first two. Here's the new plan (without the code broken up into 3 separate cached parts anymore, because that was just to troubleshoot performance): (screenshot of the new query plan)

As for the question:

Why do my calls to first() over Window.unboundedFollowing take so much longer than last() over Window.unboundedPreceding?

I'm still hoping someone can answer this, for academic reasons.

Alain