0

I have an already-written UDAF in Scala using Spark 2.4. Our Databricks cluster was on the 6.4 runtime, which is no longer supported, so we need to move to 7.3 LTS, which has long-term support and uses Spark 3. UDAF is deprecated in Spark 3 and will most likely be removed in the future, so I am trying to convert the UDAF into an Aggregator function.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{IntegerType,StringType, StructField, StructType, DataType}

/**
 * UDAF that returns, per group, the (id, name) pair having the maximum
 * non-null id — equivalent to `max(struct($"id", $"name"))` in Spark SQL.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated in Spark 3;
 * prefer `org.apache.spark.sql.expressions.Aggregator` for new code.
 */
object MaxCampaignIdAggregator extends UserDefinedAggregateFunction with java.io.Serializable{

  // Input: one (id, name) pair per record; both columns nullable.
  override def inputSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // The intermediate aggregation buffer mirrors the input shape.
  override def bufferSchema: StructType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // Final result: a struct holding the winning (id, name).
  override def dataType: DataType = new StructType()
    .add("id", IntegerType, true)
    .add("name", StringType, true)

  // Same input always yields the same output.
  override def deterministic: Boolean = true

  // Called once per group: start in the "no value seen yet" state.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
  }

  // Fold one input row into the buffer, keeping the larger non-null id.
  override def update(buffer: MutableAggregationBuffer, inputRow: Row): Unit =
    keepMax(buffer, inputRow)

  // Merge two partial aggregates using the same rule as update.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    keepMax(buffer1, buffer2)

  // Shared comparison for update/merge: the buffer takes `other`'s (id, name)
  // when the buffer is still empty, or when other's id is strictly greater.
  // A null id in `other` is ignored unless the buffer itself is still empty
  // (this preserves the original UDAF's behavior exactly).
  private def keepMax(buffer: MutableAggregationBuffer, other: Row): Unit = {
    // Read the id as a boxed value first: unboxing a null Integer via
    // getAs[Int]/getInt would silently produce 0.
    val otherId = other.get(0)
    if (buffer.isNullAt(0)) {
      buffer(0) = otherId
      buffer(1) = other.getString(1)
    } else if (otherId != null) {
      // Both ids are known non-null here, so primitive access is safe.
      if (other.getInt(0) > buffer.getInt(0)) {
        buffer(0) = otherId
        buffer(1) = other.getString(1)
      }
    }
  }

  // Emit the final (id, name) struct; fields may be null if the group
  // contained only null ids.
  override def evaluate(buffer: Row): Any =
    Row(buffer.get(0), buffer.getString(1))
}

After usage this gives output as:

{"id": 1282, "name": "McCormick Christmas"}

{"id": 1305, "name": "McCormick Perfect Pinch"}

{"id": 1677, "name": "Viking Cruises Viking Cruises"}

  • You should explain what you are trying to achieve, with sample input/output data. If I interpret your UDAF correctly, you could also write this using standard functions: `.agg(max(struct($"id", $"name")))` – Raphael Roth May 21 '21 at 19:01
  • Yes, the UDAF above was doing the same work as `max(struct($"id", $"name"))`. I added the same in Spark SQL and the whole chunk of UDAF is no longer needed. – Girish Rawat May 21 '21 at 20:52

0 Answers