
So, I've done some research but haven't found a post that addresses what I want to do.

I have a PySpark DataFrame my_df, sorted by the value column in descending order:

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
|   F|   15|
|   G|   10|
+----+-----+

The values in the value column sum to 136. I want to keep the leading rows (in this sorted order) whose running total stays within x% of 136. In this example, let's say x=80. Then the target sum = 0.8*136 = 108.8, so the new DataFrame will consist of the leading rows whose combined value is <= 108.8.

In our example, this comes down to row D, since the combined value up to D is 30+25+20+18 = 93, and adding E would bring it to 111 > 108.8.

However, the hard part is that I also want to include any immediately following rows that share the value of the last selected row. In this case, I also want to include row E, since it has the same value as row D (18).

I want to slice my_df given a percentage x, for example 80 as discussed above. The new DataFrame should consist of the following rows:

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
+----+-----+
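To pin down the selection rule, here is a minimal plain-Python sketch (not Spark code) of the behaviour described above, as read from the worked example: walk the rows in descending value order, stop before the running total would exceed x% of the grand total, then extend the selection with any following rows tied with the last kept value. The function name select_top_share and the list-of-tuples input are hypothetical, and values are assumed to be non-negative.

```python
def select_top_share(rows, pct):
    """Hypothetical helper: rows is a list of (name, value) tuples, pct a percentage."""
    # Order explicitly: value descending, then name as a deterministic tiebreaker
    ordered = sorted(rows, key=lambda r: (-r[1], r[0]))
    limit = sum(v for _, v in ordered) * pct / 100

    # Keep leading rows while the running total stays within the limit
    kept, running = 0, 0
    for _, value in ordered:
        if running + value > limit:
            break
        running += value
        kept += 1

    # Extend the selection with following rows tied with the last kept value
    if kept > 0:
        cutoff = ordered[kept - 1][1]
        while kept < len(ordered) and ordered[kept][1] == cutoff:
            kept += 1
    return ordered[:kept]


my_rows = [("A", 30), ("B", 25), ("C", 20), ("D", 18), ("E", 18), ("F", 15), ("G", 10)]
print(select_top_share(my_rows, 80))
# [('A', 30), ('B', 25), ('C', 20), ('D', 18), ('E', 18)]
```

The first loop is the cumulative-sum cutoff and the while loop is the tie extension; in Spark these become a window-based running sum and a second step that brings back the tied rows.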

One option is to iterate through the DataFrame (~360k rows), but I guess that defeats the purpose of using Spark.

Is there a concise function for what I want here?

kev

2 Answers


Use pyspark SQL functions to do this concisely.

# `target` is a threshold value defined elsewhere (not shown in this snippet)
result = my_df.filter(my_df.value > target).select(my_df.name, my_df.value)
result.show()

Edit: Based on the OP's edit to the question, compute a running sum and keep rows until the target value is reached. Note that this returns rows up to D, not E, which seems like a strange requirement.

from pyspark.sql import Window
from pyspark.sql import functions as f

# Total of the whole `value` column
total = my_df.agg(f.sum("value")).collect()[0][0]

# Ideally, order by a column that defines the ordering among rows (e.g. value descending)
w = Window.orderBy(my_df.name)
running_sum_df = my_df.withColumn('rsum', f.sum(my_df.value).over(w))
running_sum_df.filter(running_sum_df.rsum <= 0.8 * total).show()
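The comment on the Window matters: the running sum, and therefore which rows pass the rsum filter, depends entirely on the window's ordering. Here is a plain-Python model (not Spark code; kept_names is a hypothetical helper) using the rows from Vamsi Prabhala's later comment, (E,18) (F,19) (G,18) (H,20), where ordering by name and ordering by value descending keep different rows under the same 80% filter.

```python
from itertools import accumulate


def kept_names(rows, key, pct=80):
    """Order rows with `key`, then keep those whose running sum is <= pct% of the total."""
    ordered = sorted(rows, key=key)
    total = sum(v for _, v in rows)
    sums = accumulate(v for _, v in ordered)
    return [name for (name, _), s in zip(ordered, sums) if s <= total * pct / 100]


rows = [("E", 18), ("F", 19), ("G", 18), ("H", 20)]
print(kept_names(rows, key=lambda r: r[0]))           # ordered by name: ['E', 'F', 'G']
print(kept_names(rows, key=lambda r: (-r[1], r[0])))  # value descending: ['H', 'F', 'E']
```

With the name ordering the largest value (H, 20) is dropped, which is why the window should be ordered by value rather than by name.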
Gary Kerr
Vamsi Prabhala
  • I'm sorry I forgot to mention that I need to do some other operations as well, please check the update – kev Mar 27 '19 at 18:25
  • unfair to change the question entirely after it has been answered. – Vamsi Prabhala Mar 27 '19 at 18:27
  • I understand. Thanks for answering though. I still need to figure out a solution which takes care of the duplicate values. – kev Mar 27 '19 at 19:40

Your requirements are quite strict, so it's difficult to formulate an efficient solution to your problem. Nevertheless, here is one approach:

First, calculate the cumulative sum and the total sum of the value column, then filter the DataFrame using the percentage-of-total condition you specified. Let's call the result df_filtered:

import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.orderBy(f.col("value").desc(), "name").rangeBetween(Window.unboundedPreceding, 0)
target = 0.8

df_filtered = df.withColumn("cum_sum", f.sum("value").over(w))\
    .withColumn("total_sum", f.sum("value").over(Window.partitionBy()))\
    .where(f.col("cum_sum") <= f.col("total_sum")*target)

df_filtered.show()
#+----+-----+-------+---------+
#|name|value|cum_sum|total_sum|
#+----+-----+-------+---------+
#|   A|   30|     30|      136|
#|   B|   25|     55|      136|
#|   C|   20|     75|      136|
#|   D|   18|     93|      136|
#+----+-----+-------+---------+

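Then join this filtered DataFrame back to the original on the value column. The join returns every row of df whose value appears in df_filtered, so the tied row E (value 18) comes back along with D. Because the kept rows are the highest values, the only extra rows matched this way are the ties at the cutoff, so the final output contains the rows you want.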

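df.join(
    # distinct() keeps each value once, so no row is duplicated when
    # several rows in df_filtered share the same value
    df_filtered.select("value").distinct(),
    on="value"
).select("name", "value").sort(f.col("value").desc(), "name").show()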
#+----+-----+
#|name|value|
#+----+-----+
#|   A|   30|
#|   B|   25|
#|   C|   20|
#|   D|   18|
#|   E|   18|
#+----+-----+
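One caveat with the equi-join: it emits one output row per matching pair, so if df_filtered ever contains several rows with the same value (for example with target = 0.9, where D and E both have cum_sum <= 122.4), those rows come back duplicated. The plain-Python model below (not Spark code; join_on_value is a hypothetical helper) shows the effect and how joining against the distinct kept values avoids it. In PySpark that corresponds to joining on df_filtered.select("value").distinct().

```python
def join_on_value(left_rows, right_values):
    """Model of an inner equi-join on `value`: one output row per matching pair."""
    return [(name, v) for name, v in left_rows for rv in right_values if v == rv]


df_rows = [("A", 30), ("B", 25), ("C", 20), ("D", 18), ("E", 18), ("F", 15), ("G", 10)]

# target = 0.8: only D carries the value 18 in df_filtered, so E is added once
kept_80 = [30, 25, 20, 18]
print(join_on_value(df_rows, kept_80))

# target = 0.9: D and E both pass the filter, so each 18-row matches twice
kept_90 = [30, 25, 20, 18, 18]
print(join_on_value(df_rows, kept_90))

# Deduplicating the kept values first restores one row per original row
print(join_on_value(df_rows, sorted(set(kept_90), reverse=True)))
```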

The total_sum and cum_sum columns are calculated using a Window function.

The Window w orders on the value column descending, followed by the name column. The name column is used to break ties; without it, both rows D and E would have the same cumulative sum of 111 = 75+18+18, and you'd incorrectly lose both of them in the filter.

w = (
    Window
    .orderBy(                                   # Define the ordering
        f.col("value").desc(),                  # First sort by value, descending
        "name"                                  # Then by name, to break ties
    )
    .rangeBetween(Window.unboundedPreceding, 0) # Frame: first row up to the current row
)

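The rangeBetween(Window.unboundedPreceding, 0) call specifies that, for each row, the frame runs from the first row of the ordering (unboundedPreceding) up to and including the current row (0, i.e. Window.currentRow), plus any rows tied with it on the orderBy columns. Summing over that frame is what makes it a cumulative sum.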
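Because a range frame also pulls in the current row's peers (rows that tie on every orderBy column), the name tiebreaker is what separates D from E. The plain-Python model below (not Spark code; range_cum_sum is a hypothetical helper) mimics a sum over a frame from unboundedPreceding to the current row and compares ordering by value alone with ordering by value and then name.

```python
def range_cum_sum(rows, key):
    """Model of sum(value) over a RANGE frame from unboundedPreceding to the current row:
    a row's frame holds every row whose sort key is <= its own, so ties share one sum."""
    return {
        name: sum(v2 for n2, v2 in rows if key((n2, v2)) <= key((name, v)))
        for name, v in rows
    }


rows = [("A", 30), ("B", 25), ("C", 20), ("D", 18), ("E", 18), ("F", 15), ("G", 10)]

by_value = range_cum_sum(rows, key=lambda r: -r[1])               # value descending only
by_value_name = range_cum_sum(rows, key=lambda r: (-r[1], r[0]))  # plus name tiebreaker

print(by_value["D"], by_value["E"])            # 111 111 -> both exceed 108.8
print(by_value_name["D"], by_value_name["E"])  # 93 111 -> D is kept
```

With rowsBetween instead, tied rows would get distinct running sums, but their relative order would be arbitrary unless a tiebreaker is still supplied.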

pault
  • imagine there is a row with `X 18` .. the `join` would also produce that row in the output. – Vamsi Prabhala Mar 27 '19 at 19:55
  • @VamsiPrabhala but OP said that the DataFrame is sorted by value descending (and anyway the Window function handles that). – pault Mar 27 '19 at 19:55
  • even when it is sorted, it doesn't matter right..`join` just gets all rows with that value. – Vamsi Prabhala Mar 27 '19 at 19:56
  • that's correct, but the way the question is formulated implies that all rows with the 18 need to be included. (unless I'm missing something) – pault Mar 27 '19 at 19:57
  • *following* implies a sort order. Since the DataFrame is sorted by value, the rows with the same value will **always** be immediately following. None of the higher values would appear elsewhere in the DataFrame. – pault Mar 27 '19 at 20:01
  • @VamsiPrabhala By immediately following, I did not mean just 1 row - it could be multiple rows with the same value. Since it's sorted by `value`, they're bound to be next to each other. – kev Mar 27 '19 at 20:01
  • @pault is right, that's what I wanted and this solution does exactly that. Although, it'd be great if you could explain what `w = Window.orderBy(f.col("value").desc(), "name").rangeBetween(Window.unboundedPreceding, 0)` means. I'm new to this function. The `df` is already sorted by `value`, why are we ordering them again? – kev Mar 27 '19 at 20:01
  • right..i was under the assumption the df was ordered by name..in which case something like `(E,18) (F,19),(G,18) (H 20)` returns incorrect results with `join`..because `(G 18)` would be returned too. – Vamsi Prabhala Mar 27 '19 at 20:05
  • @kev I added an explanation of the Window with a link to a post that shows how to do a cumulative sum in pyspark. The DataFrame is *not* ordered- DataFrames are inherently unordered unless you explicitly specify an ordering. – pault Mar 27 '19 at 20:14
  • @pault Thanks a lot. One question though, since the `df` is already sorted by `value`, why are we ordering by `value` in `desc` again in the `Window` function? – kev Mar 27 '19 at 20:20
  • @kev do not think about Spark DataFrames as having any order. Even if it looks sorted, it's not guaranteed to be so under the hood. Spark distributes the data over multiple machines, which allows operations on them to occur in parallel. Since each executor does not have to worry about order, it can work on its part of the data independently. When you need an order, you have to specify *how* to sort and Spark will then shuffle the data between executors as needed. So the statement *`df` is already sorted by value* is false. – pault Mar 27 '19 at 20:26
  • @kev as a side note, shuffling is one of the most expensive steps in a spark job. You should try to avoid shuffles if possible- this is why I said it's hard to formulate an efficient solution to fit your requirements. – pault Mar 27 '19 at 20:31
  • @pault yeah you're right. if your solution takes too much time, I might remove the constraint of including all rows with same values in `value` column. – kev Mar 27 '19 at 20:39