
I have a Spark dataframe consisting of two columns [Employee and Salary], where Salary is in ascending order.

Sample Dataframe:

| Employee | salary |
| -------- | ------ |
| Emp1     | 10     |
| Emp2     | 20     |
| Emp3     | 30     |
| Emp4     | 35     |
| Emp5     | 36     |
| Emp6     | 50     |
| Emp7     | 70     |
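
For reproducibility, here is a minimal sketch of how such a sample dataframe could be created (the `spark` session, the variable name `dataframe`, and the exact column names are assumptions):

import spark.implicits._

// Hypothetical construction of the sample dataframe above
val dataframe = Seq(
  ("Emp1", 10), ("Emp2", 20), ("Emp3", 30), ("Emp4", 35),
  ("Emp5", 36), ("Emp6", 50), ("Emp7", 70)
).toDF("Employee", "salary")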

I want to group the rows such that each group has an aggregated salary of less than 80, and assign a category to each group, something like this. I will keep adding salaries row by row; as soon as adding a row would push the sum over 80, I assign that row a new category and start summing again.

Expected Output: 
| Employee | salary | Category |
| -------- | ------ | -------- |
| Emp1     | 10     | A        |
| Emp2     | 20     | A        |
| Emp3     | 30     | A        |
| Emp4     | 35     | B        |
| Emp5     | 36     | B        |
| Emp6     | 50     | C        |
| Emp7     | 70     | D        |

Is there a simple way to do this in Spark Scala?

  • does `dataFrame.select($"Employee", $"salary", assignACategory($"Employee", $"salary"))` work for you? – zmerr Aug 16 '21 at 08:42
  • for the salary constraint, you can try `dataFrame.select($"Employee", $"salary", assignACategory($"Employee", $"salary")).filter($"salary" < 80)` – zmerr Aug 16 '21 at 08:58
  • @James I want to assign a category after adding row values: Emp1 + Emp2 + Emp3 = 70 < 80, so category A. Then Emp4 + Emp5 = 71 < 80, so category B. Then Emp6 = 50 < 80, so category C, and similarly the last one gets D. I will keep on adding rows as long as the sum does not exceed 80. – itisha Aug 16 '21 at 09:01
  • Assuming the records get processed in order, I wonder if you can do this by writing a side-effecting `assignACategory` function inside an `object` which keeps track of the sum of the salaries it has read up to that point. Let's say the `CategoryAssignor` object keeps a `private var sum = 0` and `private var lastCategory`, and each time you call `assignACategory(salary)`, it adds the passed value to `sum` and decides which category to assign accordingly, resetting the `sum` each time it exceeds 80 and updating the `lastCategory`. – zmerr Aug 16 '21 at 09:18
  • You can take a look at [this question](https://stackoverflow.com/questions/52949899/how-to-run-spark-job-sequentially) for sequential processing of rows. – zmerr Aug 16 '21 at 11:07
  • [this blog post](https://towardsdatascience.com/adding-sequential-ids-to-a-spark-dataframe-fa0df5566ff6) is also relevant to your requirement. – zmerr Aug 16 '21 at 11:10
  • You need to calculate the [cumulative SUM](https://stackoverflow.com/a/47879075/4808122) of the *salary* in ascending order. Then simply integer-divide by 70 and map to the category. – Marmite Bomber Aug 16 '21 at 11:28
  • [this question](https://stackoverflow.com/questions/63518521/doing-cumulative-sum-for-each-year-and-month-in-sparksql?noredirect=1&lq=1) is also related to cumulative sum. – zmerr Aug 16 '21 at 11:58
  • I don't think a cumulative sum would do the trick. Here is a counter example -> 40 50 60. We should obtain 40 -> A, 50 -> B, 60 -> C, right? cumsum = 40, 90, 150. If we divide by 80, we obtain 0, 1, 1, which would put 50 and 60 together. A cumulative sum would only work if we could assume the categories to all be full (i.e. with a salary sum exactly equal to 80). – Oli Aug 16 '21 at 12:23
  • @Oli minus 80 instead of dividing? -40, 10, 30, so you create a new category each time the cumsum goes above 0, and recalculate it. – zmerr Aug 16 '21 at 13:18

1 Answer


To solve your problem, you can use a custom aggregate function over a window.

First, you need to create your custom aggregate function. An aggregate function is defined by an accumulator (a buffer) that is initialized (zero value), updated when treating a new row (reduce function) or when encountering another accumulator (merge function), and finally converted into the result (finish function).

In your case, the accumulator should keep two pieces of information:

  • Current category of employees
  • Sum of salaries of previous employees belonging to the current category

To store this information, you can use a tuple (Int, Int), where the first element is the current category and the second element is the sum of salaries of the previous employees in the current category:

  • You initialize this tuple with (0, 0).
  • When you encounter a new row, if the sum of the previous salaries plus the salary of the current row is over 80, you increment the category and reset the previous salaries' sum to the salary of the current row; otherwise, you add the salary of the current row to the previous salaries' sum.
  • As you will be applying the aggregation over a window, rows are treated sequentially, so you don't need to implement merging with another accumulator.
  • And at the end, as you only want the category, you return only the first element of the accumulator.

So we get the following aggregator implementation:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

object Labeler extends Aggregator[Int, (Int, Int), Int] {
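  // Buffer layout: (current category index, salary sum of the current category); start at (0, 0)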
  override def zero: (Int, Int) = (0, 0)

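  // If adding this salary would push the current category's sum over 80, open a new category starting with this salary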
  override def reduce(catAndSum: (Int, Int), salary: Int): (Int, Int) = {
    if (catAndSum._2 + salary > 80)
      (catAndSum._1 + 1, salary)
    else
      (catAndSum._1, catAndSum._2 + salary)
  }

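  // merge is not needed: the aggregator is only applied over an ordered window, where rows are processed sequentially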
  override def merge(catAndSum1: (Int, Int), catAndSum2: (Int, Int)): (Int, Int) = {
    throw new NotImplementedError("should be used only over a windows function")
  }

  override def finish(catAndSum: (Int, Int)): Int = catAndSum._1

  override def bufferEncoder: Encoder[(Int, Int)] = Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt)

  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

Once you have your aggregator, you transform it into a Spark aggregate function using the `udaf` function.

You then create a window over the whole dataframe, ordered by salary, and apply your Spark aggregate function over this window. Note that an unpartitioned window like this moves all rows to a single partition, which is what allows them to be processed sequentially (and limits scalability for large dataframes):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, udaf}

val labeler = udaf(Labeler)
val window = Window.orderBy("salary")

val result = dataframe.withColumn("category", labeler(col("salary")).over(window))

Using your example as the input dataframe, you get the following result dataframe:

+--------+------+--------+
|employee|salary|category|
+--------+------+--------+
|Emp1    |10    |0       |
|Emp2    |20    |0       |
|Emp3    |30    |0       |
|Emp4    |35    |1       |
|Emp5    |36    |1       |
|Emp6    |50    |2       |
|Emp7    |70    |3       |
+--------+------+--------+
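
The category comes out as an integer (0, 1, 2, ...) rather than the letters A, B, C, D from the expected output. If you want letters, here is a minimal sketch using a UDF (the helper name `toLetter` is just an illustration, and it assumes at most 26 categories):

import org.apache.spark.sql.functions.{col, udf}

// Map category index 0 -> "A", 1 -> "B", ... (assumes at most 26 categories)
val toLetter = udf((i: Int) => ('A' + i).toChar.toString)

val lettered = result.withColumn("category", toLetter(col("category")))
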
Vincent Doba