
I am trying to fit a logistic regression model for a data set with 470 features and 10 million training instances. Here is a snippet of my code.

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import RFormula

# Assemble the feature vector from every column except classWeight
formula = RFormula(formula="label ~ .-classWeight")

# Regularization settings chosen earlier via tuning
bestregLambdaVal = 0.005
bestregAlphaVal = 0.01

lr = LogisticRegression(maxIter=1000, regParam=bestregLambdaVal,
                        elasticNetParam=bestregAlphaVal, weightCol="classWeight")
pipeLineLr = Pipeline(stages=[formula, lr])
pipeLineFit = pipeLineLr.fit(mySparkDataFrame[featureColumnNameList + ['classWeight', 'label']])

I have also created a checkpoint directory,

sc.setCheckpointDir('checkpoint/')

as suggested here: Spark gives a StackOverflowError when training using ALS

However, I still get an error. Here is a partial stack trace:

File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 64, in fit
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 108, in _fit
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 64, in fit
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/wrapper.py", line 265, in _fit
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/wrapper.py", line 262, in _fit_java
  File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 63, in deco
  File "/usr/lib/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 319, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o383361.fit.
: java.lang.StackOverflowError
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1189)
    at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548)
    at java.io.ObjectOutputStream.writeSerialData(ObjectOutputStream.java:1509)
    at java.io.ObjectOutputStream.writeOrdinaryObject(ObjectOutputStream.java:1432)
    at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1178)
    at java.io.ObjectOutputStream.writeObject(ObjectOutputStream.java:348)
    at scala.collection.immutable.List$SerializationProxy.writeObject(List.scala:468)
    at sun.reflect.GeneratedMethodAccessor11.invoke(Unknown Source)

I would also like to note that the 470 feature columns were added to the Spark data frame iteratively using withColumn().
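For context, the columns were built up roughly like the sketch below. The list and transformation names are illustrative, not my actual code; the point is that each withColumn() call extends the DataFrame's logical plan, so several hundred of them produce a very deep lineage.

from pyspark.sql import functions as F

# Illustrative only: one withColumn() per feature, each call extending the lineage
for colName in rawFeatureColumnNameList:       # hypothetical list of source columns
    mySparkDataFrame = mySparkDataFrame.withColumn(
        colName + "_feature",                  # hypothetical derived column name
        F.col(colName).cast("double"))         # hypothetical transformation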


1 Answer


So the mistake I was making was that, when checkpointing the dataframe, I would only do:

mySparkDataFrame.checkpoint(eager=True)

The right way is to do:

mySparkDataFrame = mySparkDataFrame.checkpoint(eager=True)

This is based on another question I had asked (and got an answer for) here:

pyspark rdd isCheckPointed() is false

Also, it is recommended to persist() the dataframe before the checkpoint and to count() it after the checkpoint, so the checkpoint is actually materialized and the lineage is truncated; see the sketch below.
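Putting those pieces together, the fix looked roughly like this. The storage level is an assumption on my part; any persisted level should serve the same purpose.

from pyspark import StorageLevel

# Persist first so the checkpoint does not recompute the whole lineage (storage level assumed),
# then reassign the checkpointed DataFrame and force materialization with an action.
mySparkDataFrame = mySparkDataFrame.persist(StorageLevel.MEMORY_AND_DISK)
mySparkDataFrame = mySparkDataFrame.checkpoint(eager=True)
mySparkDataFrame.count()  # action that materializes the checkpoint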
