The idea is to solve this problem per partition. That works as long as each partition starts with a 3-digit number. If a partition doesn't start with one, its leading rows need the last 3-digit number of the previous partition (or of the nearest earlier partition that has one).
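For reference, this is what the per-partition version has to reproduce. The following is a minimal single-machine sketch with no Spark and no partitions, assuming the whole input fits in one local `Seq`; `tagSequential` is a hypothetical name. It walks the rows in order, remembers the latest 3-digit value and appends it to every other row, while the 3-digit rows themselves are not emitted:

```scala
// Hypothetical single-machine reference (no Spark, no partitions).
// State = (latest 3-digit value seen so far, output rows built so far).
def tagSequential(rows: Seq[String]): Seq[String] =
  rows.foldLeft(("", Vector.empty[String])) { case ((latest, out), r) =>
    if (r.length == 3) (r, out)            // new 3-digit value, emit nothing
    else (latest, out :+ s"$r $latest")    // tag the row with the latest value
  }._2

val rows = Seq(765, 11111111, 22222222, 33333333, 456, 66666666, 88888888).map(_.toString)
println(tagSequential(rows).mkString(", "))
// 11111111 765, 22222222 765, 33333333 765, 66666666 456, 88888888 456
```

The difficulty in Spark is that the fold state (`latest`) has to cross partition boundaries, which is what the helpers below deal with.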
First we need some utility functions:
The first (last3dig) finds the last 3-digit number in a partition. This gives a starting 3-digit number to the partitions that follow, in case they don't start with one. Applying it to every partition yields one entry per partition: Some(n) if the partition contains a 3-digit number (n being the last one), else None.
The second (fillGaps) fills the gaps in that list: each None is replaced by the closest preceding Some. For example, Some(1), None, None, Some(4) becomes Some(1), Some(1), Some(1), Some(4).
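The third (trasnformRow) goes over the rows of one partition of the initial RDD, starting from the 3-digit number inherited from the previous partition, and appends the most recent 3-digit number to every other row. The 3-digit rows only update that number and are dropped from the output. Together with the filled list from the second step, this populates the resulting RDD.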
```scala
// Finds the last 3-digit number in a Stream (RDD partition)
def last3dig(x: Stream[String]): Option[String] = {
  def help(y: Stream[String], sofar: Option[String]): Option[String] = {
    y match {
      case h #:: tl => if (h.length == 3) help(tl, Some(h)) else help(tl, sofar)
      case Stream.Empty => sofar
    }
  }
  help(x, None)
}
```
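A possible alternative, sketched under the assumption that you would rather not turn the partition into a `Stream` (which memoizes the rows it has traversed and is deprecated in Scala 2.13 in favour of `LazyList`): fold over the partition iterator directly and keep only the latest match. `lastThreeDigit` is a hypothetical name:

```scala
// Hypothetical variant of last3dig working on the raw partition iterator.
// Only the most recent 3-digit value is kept while the iterator is consumed.
def lastThreeDigit(it: Iterator[String]): Option[String] =
  it.foldLeft(Option.empty[String]) { (latest, s) =>
    if (s.length == 3) Some(s) else latest
  }
```

With this, the first pass could read `case (i, iter) => Iterator((i, lastThreeDigit(iter)))`.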
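```scala
// Replaces each None with the value of the closest preceding partition that has one.
// The output keeps the input order; a leading None (first partitions without any
// 3-digit number) stays None instead of failing on an empty accumulator.
def fillGaps(data: Vector[(Int, Option[String])]): Vector[(Int, Option[String])] =
  data.foldLeft(Vector.empty[(Int, Option[String])]) {
    case (acc, (i, None)) => acc :+ ((i, acc.lastOption.flatMap(_._2)))
    case (acc, n)         => acc :+ n
  }
```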
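The same forward fill can also be written with `scanLeft` on a plain sequence of `Option`s, since each position is simply `cur.orElse(prev)`. This is a sketch with a hypothetical name, `forwardFill`; for the RDD case it would be applied to `mappings.map(_._2)`, relying on `collect()` returning the partitions in index order:

```scala
// Hypothetical generic forward fill: None takes the closest preceding Some.
// scanLeft emits the seed first, so .tail drops it.
def forwardFill[A](xs: Seq[Option[A]]): Seq[Option[A]] =
  xs.scanLeft(Option.empty[A])((prev, cur) => cur.orElse(prev)).tail

println(forwardFill(Seq(Some(1), None, None, Some(4))))
// List(Some(1), Some(1), Some(1), Some(4))
```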
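```scala
// Appends the current filler (latest 3-digit number) to every non-3-digit row.
// 3-digit rows are not emitted; they only replace the filler for the rows after them.
def trasnformRow(x: Stream[String], filler: String): Stream[String] = x match {
  case h #:: tl =>
    if (h.length != 3)
      s"$h $filler" #:: trasnformRow(tl, filler)
    else
      trasnformRow(tl, h) // update the filler
  case Stream.Empty => x
}
```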
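If the partitions are large, an iterator-based version may be preferable, because it consumes the rows lazily without a memoized `Stream`. A sketch, with `tagPartition` as a hypothetical name: `scanLeft` carries the latest 3-digit value as state, and `collect` keeps only the tagged rows:

```scala
// Hypothetical iterator-based variant of trasnformRow.
// State = (latest 3-digit value, optional output row for the current input row).
def tagPartition(rows: Iterator[String], filler: String): Iterator[String] =
  rows
    .scanLeft((filler, Option.empty[String])) { case ((latest, _), r) =>
      if (r.length == 3) (r, None)           // update the filler, emit nothing
      else (latest, Some(s"$r $latest"))     // tag the row
    }
    .collect { case (_, Some(out)) => out }
```

In the second pass it would replace `trasnformRow(iter.toStream, filler).iterator` with `tagPartition(iter, filler)`.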
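```scala
// toy data
val d = Seq(765, 11111111, 22222222, 33333333, 456, 66666666, 88888888).map(_.toString)
val rdd = sc.makeRDD(d, 2)

// Pass 1: (partition index, last 3-digit number in that partition), collected to the driver
val mappings = rdd.mapPartitionsWithIndex {
  case (i, iter) => Iterator((i, last3dig(iter.toStream)))
}.collect().toVector

val filledMappings = fillGaps(mappings)
val mm = sc.broadcast(filledMappings.toMap)

// Pass 2: partition i starts from the (filled) last 3-digit number of partition i - 1;
// "" is used when no earlier partition contains one
val finalResult = rdd.mapPartitionsWithIndex {
  case (i, iter) =>
    val filler = if (i > 0) mm.value(i - 1).getOrElse("") else ""
    trasnformRow(iter.toStream, filler).iterator
}.collect() // remove collect() for large datasets
```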
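To check the two-pass logic without a cluster, here is a local simulation in which each inner `Seq` stands in for one RDD partition. It is a sketch, not Spark code: all names are hypothetical, and the split below assumes `sc.makeRDD(d, 2)` slices the 7 toy elements as 3 + 4. The result should match a single sequential pass over the same rows:

```scala
// Local simulation of the two passes (no Spark). All names are hypothetical.

// Pass 1: last 3-digit value per partition (None if the partition has none).
def lastHeader(p: Seq[String]): Option[String] = p.reverse.find(_.length == 3)

// Fill gaps: a None inherits the closest preceding Some.
def forwardFill(xs: Seq[Option[String]]): Seq[Option[String]] =
  xs.scanLeft(Option.empty[String])((prev, cur) => cur.orElse(prev)).tail

// Pass 2: tag non-3-digit rows with the latest 3-digit value, starting from `filler`.
def tagPartition(p: Seq[String], filler: String): Seq[String] =
  p.foldLeft((filler, Vector.empty[String])) { case ((latest, out), r) =>
    if (r.length == 3) (r, out) else (latest, out :+ s"$r $latest")
  }._2

def runLocally(partitions: Seq[Seq[String]]): Seq[String] = {
  val filled = forwardFill(partitions.map(lastHeader))
  // Partition i starts from the filled value of partition i - 1 ("" for i = 0).
  val fillers = "" +: filled.dropRight(1).map(_.getOrElse(""))
  partitions.zip(fillers).flatMap { case (p, f) => tagPartition(p, f) }
}

val parts = Seq(
  Seq("765", "11111111", "22222222"),
  Seq("33333333", "456", "66666666", "88888888"))
println(runLocally(parts).mkString(", "))
// 11111111 765, 22222222 765, 33333333 765, 66666666 456, 88888888 456
```

A partition with no 3-digit number at all, e.g. the middle one in `Seq(Seq("765", "1"), Seq("2", "3"), Seq("4"))`, passes `765` through to the next partition, which is exactly the case fillGaps exists for.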