I have the following code, which computes some metrics by cross-validation for a random forest classification.
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import scala.collection.parallel.mutable.ParArray

def run(data: RDD[LabeledPoint], metric: String = "PR") = {
  val cv_data: Array[(RDD[LabeledPoint], RDD[LabeledPoint])] = MLUtils.kFold(data, numFolds, 0)
  val result: Array[(Double, Double)] = cv_data.par.map { case (training, validation) =>
    training.persist(StorageLevel.MEMORY_ONLY)
    validation.persist(StorageLevel.MEMORY_ONLY)
    val res: ParArray[(Double, Double)] = CV_params.par.zipWithIndex.map { case (p, i) =>
      // Training classifier
      val model = RandomForest.trainClassifier(training, numClasses, categoricalFeaturesInfo,
        p(0).asInstanceOf[Int], p(3).asInstanceOf[String], p(4).asInstanceOf[String],
        p(1).asInstanceOf[Int], p(2).asInstanceOf[Int])
      // Prediction
      val labelAndPreds: RDD[(Double, Double)] = model.predictWithLabels(validation)
      // Metrics computation
      val bcm = new BinaryClassificationMetrics(labelAndPreds)
      (bcm.areaUnderROC() / numFolds, bcm.areaUnderPR() / numFolds)
    }
    training.unpersist()
    validation.unpersist()
    res
  }.reduce((s1, s2) => s1.zip(s2).map(t => (t._1._1 + t._2._1, t._1._2 + t._2._2))).toArray

  val cv_roc = result.map(_._1)
  val cv_pr = result.map(_._2)

  // Extract best params
  val which_max = (metric match {
    case "ROC" => cv_roc
    case "PR" => cv_pr
    case _ =>
      logWarning("Metrics set to default one: PR")
      cv_pr
  }).zipWithIndex.maxBy(_._1)._2

  best_values_array = CV_params(which_max)
  CV_areaUnderROC = cv_roc
  CV_areaUnderPR = cv_pr
}
val numTrees = Array(50)
val maxDepth = Array(30)
val maxBins = Array(100)
val featureSubsetStrategy = Array("sqrt")
val impurity = Array("gini")
val CV_params: Array[Array[Any]] = {
  for (a <- numTrees; b <- maxDepth; c <- maxBins; d <- featureSubsetStrategy;
       e <- impurity) yield Array(a, b, c, d, e)
}
run(data, "PR")
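Note: predictWithLabels is not a standard MLlib method; it is a small helper that, roughly, does the equivalent of the following sketch (included here for completeness; the real implementation may differ):

import org.apache.spark.mllib.tree.model.RandomForestModel

// Hypothetical sketch of predictWithLabels: produce (score, label) pairs for the validation set,
// which is the input format expected by BinaryClassificationMetrics.
def predictWithLabels(model: RandomForestModel, dataset: RDD[LabeledPoint]): RDD[(Double, Double)] =
  dataset.map(point => (model.predict(point.features), point.label))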
It runs on a YARN cluster with 50 containers (26 GB of memory in total). The data parameter is an RDD[LabeledPoint]. I use Kryo serialization and a default level of parallelism of 1000.
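For reference, the relevant Spark configuration looks roughly like this (a sketch; the executor and memory settings are passed through spark-submit and the values are approximate):

import org.apache.spark.SparkConf

// Rough sketch of the driver-side configuration (values approximate):
val conf = new SparkConf()
  .setAppName("RandomForestCV")                                          // hypothetical app name
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") // Kryo serialization
  .set("spark.default.parallelism", "1000")                              // default level of parallelism
// Executors: 50 YARN containers, about 26 GB of memory in total across them.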
It works for small datasets, but with my real data of 600,000 points I get the following error:
Exception in thread "dag-scheduler-event-loop" java.lang.OutOfMemoryError: Java heap space
at java.util.Arrays.copyOf(Arrays.java:2271)
at java.io.ByteArrayOutputStream.grow(ByteArrayOutputStream.java:113)
at java.io.ByteArrayOutputStream.ensureCapacity(ByteArrayOutputStream.java:93)
at java.io.ByteArrayOutputStream.write(ByteArrayOutputStream.java:140)
at java.io.ObjectOutputStream$BlockDataOutputStream.write(ObjectOutputStream.java:1841)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1533)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431)
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1177)
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1547)
at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1508)
at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1431)
I can't figure out where the error comes from, because the total allocated memory (26 GB) is much higher than what is actually consumed during the job (I checked in the Spark web UI).
Any help would be appreciated. Thank you!