Here is an algorithm I worked out to do this:
EXAMPLE PROBLEM
Assume we want to sample 10 items from an RDD on 3 partitions like this:
- P1: ("A", 0.10), ("B", 0.10), ("C", 0.20)
- P2: ("D": 0.25), ("E", 0.25)
- P3: ("F", 0.10)
Here is the high-level algorithm:
INPUT: the number of samples and an RDD of items (with weights)
OUTPUT: a dataset sample

On the driver:
- For each partition, calculate the total weight of its items, and aggregate those totals to the driver.
- Normalizing by the overall total weight (here 0.10 + 0.10 + 0.20 + 0.25 + 0.25 + 0.10 = 1.00) gives the probability distribution over partitions:
Prob(P1) = 0.40, Prob(P2) = 0.50, Prob(P3) = 0.10
- Generate a sample of the partition indexes, to determine the number of elements to select from each partition (a standalone sketch of this draw appears after this list).
- A sample may look like this:
[P1, P1, P1, P1, P2, P2, P2, P2, P2, P3]
- This would give us 4 items from P1, 5 items from P2, and 1 item from P3.
- On each separate partition, we locally generate a sample of the needed size using only the elements on that partition:
- On P1, we would sample 4 items with the re-normalized probability distribution, each weight divided by the partition total of 0.40:
Prob(A) = 0.25, Prob(B) = 0.25, Prob(C) = 0.50.
This could yield a sample such as [A, B, C, C].
- On P2, we would sample 5 items with probability distribution:
Prob(D) = 0.50, Prob(E) = 0.50.
This could yield a sample such as [D, D, E, E, E].
- On P3, we would sample 1 item with probability distribution:
Prob(F) = 1.0.
This would generate the sample [F].
- Collect the samples to the driver to yield your dataset sample [A, B, C, C, D, D, E, E, E, F].
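As promised above, here is a standalone sketch of the partition-index draw, using the sample helper defined at the end of this answer (the counts shown are just one possible outcome):

// Distribution over partition indexes from the example above
val partitionDist = Array((1, 0.40), (2, 0.50), (3, 0.10))
// Draw 10 partition indexes and count how many samples each partition owes
val samplesPerPartition = sample(partitionDist, 10).groupBy(identity).mapValues(_.size).toMap
// One possible outcome: Map(1 -> 4, 2 -> 5, 3 -> 1)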
Here is an implementation in Scala:
import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

case class Sample[T](weight: Double, obj: T)

/*
 * Obtain a sample of size `numSamples` from an RDD `ar` using a two-phase distributed sampling approach.
 */
def sampleWeightedRDD[T: ClassTag](ar: RDD[Sample[T]], numSamples: Int)(implicit sc: SparkContext): Array[T] = {
  // 1. Get the total weight on each partition
  val rawWeights = ar.mapPartitionsWithIndex { case (partitionIndex, iter) =>
    Iterator((partitionIndex, iter.map(_.weight).sum))
  }.collect()
  // Normalize so the partition weights sum to 1.0
  val Z = rawWeights.map(_._2).sum
  val partitionWeights = rawWeights.map { case (partitionIndex, weight) => (partitionIndex, weight / Z) }
  // 2. Sample from the partition indexes to determine the number of samples needed from each
  //    partition, and broadcast the counts so every task can read them cheaply
  val samplesPerIndex = sc.broadcast(
    sample[Int](partitionWeights, numSamples).groupBy(identity).mapValues(_.size).toMap
  )
  // 3. On each partition, sample the number of elements needed for that partition
  ar.mapPartitionsWithIndex { case (partitionIndex, iter) =>
    val numSamplesForPartition = samplesPerIndex.value.getOrElse(partitionIndex, 0)
    val items = iter.map(x => (x.obj, x.weight)).toArray
    // Re-normalize the weights within this partition to sum to 1.0
    val z = items.map(_._2).sum
    val dist = items.map { case (obj, weight) => (obj, weight / z) }
    sample(dist, numSamplesForPartition).toIterator
  }.collect()
}
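For completeness, this is how the pieces fit together on the example data (the rdd value is the illustrative construction from the beginning of this answer; the exact sample varies from run to run):

implicit val sparkContext: SparkContext = sc // satisfy the implicit parameter
val drawn: Array[String] = sampleWeightedRDD(rdd, 10)
// One possible outcome: Array(A, B, C, C, D, D, E, E, E, F)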
This code uses a simple weighted sampling function `sample`:
// A very simple weighted sampling function: inverse-CDF sampling over the
// cumulative distribution of the (normalized) weights
def sample[T: ClassTag](dist: Array[(T, Double)], numSamples: Int): Array[T] = {
  // Sort by descending probability so likely elements are found early in the scan
  val probs = dist.sortBy(-_._2)
  // Cumulative distribution, e.g. probabilities (0.5, 0.3, 0.2) -> (0.5, 0.8, 1.0)
  val cumulativeDist = probs.map(_._2).scanLeft(0.0)(_ + _).drop(1)
  Array.fill(numSamples)(scala.util.Random.nextDouble()).map { p =>
    // Index of the first cumulative value reaching p; fall back to the last
    // element to guard against floating-point round-off
    val idx = cumulativeDist.indexWhere(p <= _)
    probs(if (idx >= 0) idx else cumulativeDist.length - 1)._1
  }
}
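For reference, a quick standalone use of sample with a hypothetical biased coin (output varies per run):

val biasedCoin = Array(("H", 0.75), ("T", 0.25))
sample(biasedCoin, 4) // one possible outcome: Array(H, H, H, T)

Note that the linear scan over the cumulative distribution makes each draw O(n) in the number of elements; for large per-partition distributions, replacing it with a binary search over cumulativeDist would be the natural optimization.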