
I am trying to use a custom HashMap implementation as a UserDefinedType instead of MapType in Spark. The code works fine in Spark 1.5.2, but in Spark 1.6.2 it throws java.lang.ClassCastException: scala.collection.immutable.HashMap$HashMap1 cannot be cast to org.apache.spark.sql.catalyst.util.MapData.

The code:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.immutable.HashMap

class Test extends UserDefinedAggregateFunction {

  def inputSchema: StructType =
    StructType(Array(StructField("input", StringType)))

  def bufferSchema: StructType = StructType(Array(StructField("top_n", CustomHashMapType)))

  def dataType: DataType = CustomHashMapType

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = HashMap.empty[String, Long]
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val buff0 = buffer.getAs[HashMap[String, Long]](0)
    buffer(0) = buff0.updated("test", buff0.getOrElse("test", 0L) + 1L)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    buffer1(0) = buffer1.getAs[HashMap[String, Long]](0)
      .merged(buffer2.getAs[HashMap[String, Long]](0)) { case ((k, v1), (_, v2)) => (k, v1 + v2) }
  }

  def evaluate(buffer: Row): Any = {
    buffer(0)
  }
}

private case object CustomHashMapType extends UserDefinedType[HashMap[String, Long]] {

  override def sqlType: DataType = MapType(StringType, LongType)

  override def serialize(obj: Any): Map[String, Long] =
    obj.asInstanceOf[Map[String, Long]]

  override def deserialize(datum: Any): HashMap[String, Long] = {
    datum.asInstanceOf[Map[String, Long]] ++: HashMap.empty[String, Long]
  }

  override def userClass: Class[HashMap[String, Long]] = classOf[HashMap[String, Long]]

}

The wrapper class to run the UDAF:

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object TestJob {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[4]").setAppName("DataStatsExecution")
    val sc = new SparkContext(conf)

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    val df = sc.parallelize(Seq(1,2,3,4)).toDF("col")
    val udaf = new Test()
    val outdf = df.agg(udaf(df("col")))
    outdf.show
  }
}

When I run the above code in Spark 1.6.2, I get the following exception:

Caused by: java.lang.ClassCastException: scala.collection.immutable.HashMap$HashMap1 cannot be cast to org.apache.spark.sql.catalyst.util.MapData
    at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getMap(rows.scala:50)
    at org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getMap(rows.scala:248)
    at org.apache.spark.sql.catalyst.expressions.JoinedRow.getMap(JoinedRow.scala:115)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$31.apply(AggregationIterator.scala:345)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$31.apply(AggregationIterator.scala:344)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:154)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
    at org.apache.spark.scheduler.Task.run(Task.scala:89)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:227)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

I have found the HashMap implementation to be much faster than Spark's built-in MapType implementation. Are there any changes that would make this code run in Spark 1.6.2, or is there a possible alternative?
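One change I am considering, based on the error, is having serialize/deserialize work with Catalyst's internal MapData representation directly instead of plain Scala maps, since the cast failure suggests that is what Spark 1.6 expects when sqlType is a MapType. This is only a rough, untested sketch; I am assuming ArrayBasedMapData, GenericArrayData and UTF8String are the right internal classes to use for this:

import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.immutable.HashMap

private case object CustomHashMapType extends UserDefinedType[HashMap[String, Long]] {

  override def sqlType: DataType = MapType(StringType, LongType)

  // Convert the Scala HashMap into Catalyst's internal MapData representation
  // (keys stored as UTF8String, values as Long).
  override def serialize(obj: Any): MapData = {
    val map = obj.asInstanceOf[Map[String, Long]]
    val keys: Array[Any] = map.keys.map(k => UTF8String.fromString(k): Any).toArray
    val values: Array[Any] = map.values.map(v => v: Any).toArray
    new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
  }

  // Rebuild the Scala HashMap from the internal MapData representation.
  override def deserialize(datum: Any): HashMap[String, Long] = {
    val mapData = datum.asInstanceOf[MapData]
    val keys = mapData.keyArray()
    val values = mapData.valueArray()
    (0 until mapData.numElements()).foldLeft(HashMap.empty[String, Long]) { (acc, i) =>
      acc.updated(keys.getUTF8String(i).toString, values.getLong(i))
    }
  }

  override def userClass: Class[HashMap[String, Long]] = classOf[HashMap[String, Long]]
}

I have not verified whether this keeps the performance advantage of the HashMap buffer, so I would still welcome other approaches.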
