
I'm trying to compare performance between R and Spark-ML, and my initial tests show that Spark-ML outperforms R in most cases and scales much better as the dataset grows.

However, I'm seeing strange results with Gradient Boosted Trees: R takes 3 minutes where Spark takes 15 on the same dataset, on the same computer.
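(As an aside, the timings above are plain wall-clock measurements around each fit; on the Scala side, a minimal timing helper of the kind one might use looks like the sketch below. The `timed` name is my own, and nothing beyond the standard library is assumed.)

```scala
// Minimal wall-clock timer: runs a block and returns its result
// together with the elapsed time in seconds.
def timed[A](block: => A): (A, Double) = {
  val start = System.nanoTime()
  val result = block // force evaluation of the by-name block
  (result, (System.nanoTime() - start) / 1e9)
}

// Usage with a cheap stand-in computation; in the real test this
// would wrap the pipeline fit instead:
val (total, seconds) = timed { (1L to 1000000L).sum }
```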

Here is the R code:

train <- read.table("c:/Path/to/file.csv", header=T, sep=";",dec=".")
train$X1 <- factor(train$X1)
train$X2 <- factor(train$X2)
train$X3 <- factor(train$X3)
train$X4 <- factor(train$X4)
train$X5 <- factor(train$X5)
train$X6 <- factor(train$X6)
train$X7 <- factor(train$X7)
train$X8 <- factor(train$X8)
train$X9 <- factor(train$X9)

library(gbm)
boost <- gbm(Freq ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + Y1,
             distribution = "gaussian", data = train, n.trees = 2000,
             bag.fraction = 1, shrinkage = 1, interaction.depth = 1,
             n.minobsinnode = 50, train.fraction = 1.0, cv.folds = 0,
             keep.data = TRUE)
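For what it's worth, here is how I read the correspondence between the gbm arguments above and the GBTRegressor settings in the Spark code below. This mapping is my own interpretation of the two documentations, so treat it as an assumption rather than an established equivalence:

```scala
// Rough gbm -> GBTRegressor parameter correspondence (my reading of
// both docs, not an authoritative mapping):
val paramMap = Map(
  "n.trees = 2000"              -> "setMaxIter(2000)",
  "interaction.depth = 1"       -> "setMaxDepth(1)",
  "shrinkage = 1"               -> "setStepSize(1)",
  "n.minobsinnode = 50"         -> "setMinInstancesPerNode(50)",
  "bag.fraction = 1"            -> "setSubsamplingRate(1)",
  "distribution = \"gaussian\"" -> "lossType = \"squared\" (the default)"
)
```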

And here is the Scala code for Spark:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.regression.GBTRegressor

val conf = new SparkConf()
  .setAppName("GBTExample")
  .set("spark.driver.memory", "8g")
  .set("spark.executor.memory", "8g")
  .set("spark.network.timeout", "120s")
val sc = SparkContext.getOrCreate(conf.setMaster("local[8]"))
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

val sourceData = spark.read.format("csv")
  .option("header", "true")
  .option("delimiter", ";")
  .option("inferSchema", "true")
  .load("c:/Path/to/file.csv")

val data = sourceData.select($"X1", $"X2", $"X3", $"X4", $"X5", $"X6", $"X7", $"X8", $"X9", $"Y1".cast("double"), $"Freq".cast("double"))

val X1Indexer = new StringIndexer().setInputCol("X1").setOutputCol("X1Idx")
val X2Indexer = new StringIndexer().setInputCol("X2").setOutputCol("X2Idx")
val X3Indexer = new StringIndexer().setInputCol("X3").setOutputCol("X3Idx")
val X4Indexer = new StringIndexer().setInputCol("X4").setOutputCol("X4Idx")
val X5Indexer = new StringIndexer().setInputCol("X5").setOutputCol("X5Idx")
val X6Indexer = new StringIndexer().setInputCol("X6").setOutputCol("X6Idx")
val X7Indexer = new StringIndexer().setInputCol("X7").setOutputCol("X7Idx")
val X8Indexer = new StringIndexer().setInputCol("X8").setOutputCol("X8Idx")
val X9Indexer = new StringIndexer().setInputCol("X9").setOutputCol("X9Idx")

val assembler = new VectorAssembler()
  .setInputCols(Array("X1Idx", "X2Idx", "X3Idx", "X4Idx", "X5Idx", "X6Idx", "X7Idx", "X8Idx", "X9Idx", "Y1"))
  .setOutputCol("features")

val dt = new GBTRegressor()
  .setLabelCol("Freq")
  .setFeaturesCol("features")
  .setImpurity("variance")
  .setMaxIter(2000)
  .setMinInstancesPerNode(50)
  .setMaxDepth(1)
  .setStepSize(1)
  .setSubsamplingRate(1)
  .setMaxBins(32)

val pipeline = new Pipeline()
  .setStages(Array(X1Indexer, X2Indexer, X3Indexer, X4Indexer, X5Indexer, X6Indexer, X7Indexer, X8Indexer, X9Indexer, assembler, dt))

val model = pipeline.fit(data)

I have the feeling that I'm not comparing the same methods here, but the documentation that I could find did not clarify the situation.
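One thing I did convince myself of is that the losses should match: gbm with distribution = "gaussian" minimizes squared error, which I believe is also what GBTRegressor uses with its default lossType ("squared"). A plain-Scala sketch of that loss, just to pin down what both libraries should be minimizing (no Spark needed to run it):

```scala
// Mean squared error: the loss minimized by gbm(distribution = "gaussian")
// and, as far as I can tell, by GBTRegressor with the default lossType.
def meanSquaredError(pred: Seq[Double], label: Seq[Double]): Double =
  pred.zip(label).map { case (p, y) => (p - y) * (p - y) }.sum / pred.size

// Residuals are (0, 0, -1), so the MSE is 1/3.
val mse = meanSquaredError(Seq(1.0, 2.0, 3.0), Seq(1.0, 2.0, 4.0))
```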

OBones
  • Your "issue" falls into the same families as what I have stated in my answer here https://stackoverflow.com/questions/44866488/spark-ml-pipeline-with-randomforest-takes-too-long-on-20mb-dataset/44905657#44905657 – eliasah Jul 10 '17 at 09:41
  • Thanks for your comment, I understand that there is a setup cost associated with Spark, which means that small datasets will not be processed efficiently. But I also tried it with a dataset containing 3 million rows, and I saw the same kind of duration difference. For 10 trees with maxDepth = 2, R takes less than 3 minutes while Spark takes 47 minutes. All other methods scale very well in Spark, but for GBT, it's as if it spends much of its time waiting for something. When looking at CPU usage, I see short spikes every 2 minutes or so, and almost no usage in between. – OBones Jul 11 '17 at 10:17
  • Data skewness may be an issue here. – eliasah Jul 11 '17 at 11:33

0 Answers