
I am facing an issue with a Spark Streaming job where I am trying to use broadcast, mapWithState and checkpointing together.

Following is the usage:

  • Since I have to pass a connection object (which is not Serializable) to the executors, I am using org.apache.spark.broadcast.Broadcast
  • Since we have to maintain some cached information, I am using stateful streaming with mapWithState
  • I am also checkpointing my streaming context

I also need to pass the broadcast connection object into the mapWithState function to fetch some data from an external source.

The flow works just fine when the context is created fresh. However, when I crash the application and try to recover from the checkpoint, I get a ClassCastException.

I have put a small code snippet, based on an example from asyncified.io, on GitHub to reproduce the issue:

  • My broadcast logic is in yuvalitzchakov.utils.KafkaWriter.scala
  • The dummy logic of the application is in yuvalitzchakov.stateful.SparkStatefulRunnerWithBroadcast.scala

Dummy snippet of the code:

val sparkConf = new SparkConf().setMaster("local[*]").setAppName("spark-stateful-example")

...
val prop = new Properties()
...

val config: Config = ConfigFactory.parseString(prop.toString)
val sc = new SparkContext(sparkConf)
val ssc = StreamingContext.getOrCreate(checkpointDir, () =>  {

    println("creating context newly")

    clearCheckpoint(checkpointDir)

    val streamingContext = new StreamingContext(sc, Milliseconds(batchDuration))
    streamingContext.checkpoint(checkpointDir)

    ...
    val kafkaWriter = SparkContext.getOrCreate().broadcast(kafkaErrorWriter)
    ...
    val stateSpec = StateSpec.function((key: Int, value: Option[UserEvent], state: State[UserSession]) =>
        updateUserEvents(key, value, state, kafkaWriter)).timeout(Minutes(jobConfig.getLong("timeoutInMinutes")))

    kafkaTextStream
    .transform(rdd => {
        offsetsQueue.enqueue(rdd.asInstanceOf[HasOffsetRanges].offsetRanges)
        rdd
    })
    .map(deserializeUserEvent)
    .filter(_ != UserEvent.empty)
    .mapWithState(stateSpec)
    .foreachRDD { rdd =>
        ...
        some logic
        ...
    }

    streamingContext
})

ssc.start()
ssc.awaitTermination()


def updateUserEvents(key: Int,
                     value: Option[UserEvent],
                     state: State[UserSession],
                     kafkaWriter: Broadcast[KafkaWriter]): Option[UserSession] = {

    ...
    kafkaWriter.value.someMethodCall()
    ...
}

I get the following error when

kafkaWriter.value.someMethodCall()

is executed:

17/08/01 21:20:38 ERROR Executor: Exception in task 2.0 in stage 3.0 (TID 4)
java.lang.ClassCastException: org.apache.spark.util.SerializableConfiguration cannot be cast to yuvalitzchakov.utils.KafkaWriter
    at yuvalitzchakov.stateful.SparkStatefulRunnerWithBroadcast$.updateUserSessions$1(SparkStatefulRunnerWithBroadcast.scala:144)
    at yuvalitzchakov.stateful.SparkStatefulRunnerWithBroadcast$.updateUserEvents(SparkStatefulRunnerWithBroadcast.scala:150)
    at yuvalitzchakov.stateful.SparkStatefulRunnerWithBroadcast$$anonfun$2.apply(SparkStatefulRunnerWithBroadcast.scala:78)
    at yuvalitzchakov.stateful.SparkStatefulRunnerWithBroadcast$$anonfun$2.apply(SparkStatefulRunnerWithBroadcast.scala:77)
    at org.apache.spark.streaming.StateSpec$$anonfun$1.apply(StateSpec.scala:181)
    at org.apache.spark.streaming.StateSpec$$anonfun$1.apply(StateSpec.scala:180)
    at org.apache.spark.streaming.rdd.MapWithStateRDDRecord$$anonfun$updateRecordWithData$1.apply(MapWithStateRDD.scala:57)
    at org.apache.spark.streaming.rdd.MapWithStateRDDRecord$$anonfun$updateRecordWithData$1.apply(MapWithStateRDD.scala:55)
    at scala.collection.Iterator$class.foreach(Iterator.scala:893)
    at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
    at org.apache.spark.streaming.rdd.MapWithStateRDDRecord$.updateRecordWithData(MapWithStateRDD.scala:55)
    at org.apache.spark.streaming.rdd.MapWithStateRDD.compute(MapWithStateRDD.scala:159)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD$$anonfun$8.apply(RDD.scala:336)
    at org.apache.spark.rdd.RDD$$anonfun$8.apply(RDD.scala:334)
    at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1005)
    at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:996)
    at org.apache.spark.storage.BlockManager.doPut(BlockManager.scala:936)
    at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:996)
    at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:700)
    at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:334)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:285)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:99)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:322)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)

Basically, kafkaWriter is the broadcast variable and kafkaWriter.value should return the broadcast KafkaWriter instance, but after recovery from the checkpoint it returns a SerializableConfiguration, which cannot be cast to the desired object.

Thanks in advance for any help!

Saman
  • Why do you need `KafkaWriter` inside `mapWithState`? Would it be possible to create the call prior to updating your state? Something that could possibly run inside `mapPartitions` instead? BTW, your example seems to have a copy/paste error since some code is replicated two times. – Yuval Itzchakov Aug 01 '17 at 19:43
  • Thanks for the reply, Yuval. This is a dummy example just to reproduce the issue. In our real use case we have to fetch some data by making a JDBC call to a database, which we use to update the state, so we have to pass the broadcast into mapWithState. Also, if you are referring to SparkStateRunner and SparkStateRunnerWithBroadcast as the replicated code, the former does not have the broadcast passed into mapWithState while the latter does. – Saman Aug 02 '17 at 03:35
  • I see. Did you consider calling the JDBC driver before the call to `mapWithState`? – Yuval Itzchakov Aug 02 '17 at 05:41
  • Based on some event we have to fetch data from an external source. We do not want to make a DB call every time, as these calls are expensive, so it would not make sense for us to do the JDBC call before mapWithState. – Saman Aug 03 '17 at 05:08

1 Answer


A broadcast variable cannot be used with mapWithState (or with transformation operations in general) if we need to recover from the checkpoint directory in Spark Streaming. In that case it can only be used inside output operations, because the broadcast has to be lazily re-initialized with the SparkContext after recovery. The example below, the lazily-instantiated singleton pattern from the Spark Streaming programming guide, registers the broadcast (and an accumulator) inside foreachRDD, which is an output operation:

class JavaWordBlacklist {

  private static volatile Broadcast<List<String>> instance = null;

  public static Broadcast<List<String>> getInstance(JavaSparkContext jsc) {
    if (instance == null) {
      synchronized (JavaWordBlacklist.class) {
        if (instance == null) {
          List<String> wordBlacklist = Arrays.asList("a", "b", "c");
          instance = jsc.broadcast(wordBlacklist);
        }
      }
    }
    return instance;
  }
}

class JavaDroppedWordsCounter {

  private static volatile LongAccumulator instance = null;

  public static LongAccumulator getInstance(JavaSparkContext jsc) {
    if (instance == null) {
      synchronized (JavaDroppedWordsCounter.class) {
        if (instance == null) {
          instance = jsc.sc().longAccumulator("WordsInBlacklistCounter");
        }
      }
    }
    return instance;
  }
}

wordCounts.foreachRDD((rdd, time) -> {
  // Get or register the blacklist Broadcast
  Broadcast<List<String>> blacklist =
      JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context()));
  // Get or register the droppedWordsCounter Accumulator
  LongAccumulator droppedWordsCounter =
      JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context()));
  // Use the blacklist to drop words and the droppedWordsCounter to count them
  String counts = rdd.filter(wordCount -> {
    if (blacklist.value().contains(wordCount._1())) {
      droppedWordsCounter.add(wordCount._2());
      return false;
    } else {
      return true;
    }
  }).collect().toString();
  String output = "Counts at time " + time + " " + counts;
});
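
Applied to the question's code, a minimal Scala sketch of the same pattern could look like the following (untested; KafkaWriter, someMethodCall, stateSpec, deserializeUserEvent, UserEvent and kafkaTextStream are the names from the question, and the no-argument KafkaWriter constructor is assumed purely for illustration). The external call moves out of the state update function and into foreachRDD, which is an output operation:

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast

// Lazily-instantiated singleton holding the broadcast, so it is re-created on the
// driver after recovery instead of being restored from the checkpointed closures.
object KafkaWriterBroadcast {

  @volatile private var instance: Broadcast[KafkaWriter] = _

  def getInstance(sc: SparkContext): Broadcast[KafkaWriter] = {
    if (instance == null) {
      synchronized {
        if (instance == null) {
          // Build the writer here so nothing non-serializable is captured by the
          // checkpointed DStream graph (the no-arg constructor is hypothetical).
          instance = sc.broadcast(new KafkaWriter())
        }
      }
    }
    instance
  }
}

// The broadcast is only accessed inside foreachRDD, an output operation, and the
// StateSpec function no longer closes over it.
kafkaTextStream
  .map(deserializeUserEvent)
  .filter(_ != UserEvent.empty)
  .mapWithState(stateSpec)
  .foreachRDD { rdd =>
    val writer = KafkaWriterBroadcast.getInstance(rdd.sparkContext)
    rdd.foreachPartition { sessions =>
      sessions.foreach(_ => writer.value.someMethodCall())
    }
  }

This keeps the non-serializable writer out of the checkpointed state function entirely, which is what the restriction above requires.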
metsathya