I am using PySpark's random forest classifier and would like to create a pandas dataframe from the predictions once I get them back. A strange exception occurs when I try to do so. Here is my code:
from pyspark.ml.classification import RandomForestClassifier

random_forest = RandomForestClassifier(labelCol='label', featuresCol='features',
                                       maxDepth=4, impurity='entropy',
                                       numTrees=10, maxBins=250)
rf_model = random_forest.fit(training_data)
predictions = rf_model.transform(test_data)
# This is where the exception happens
df = predictions.select('rawPrediction', 'label', 'prediction').where((predictions.label == '1.0') & (predictions.prediction == '0.0')).toPandas()
# The code that works fine
label_pred_train = predictions.select('label', 'prediction')
print(label_pred_train.rdd.zipWithIndex().countByKey())
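For context, there is a StringIndexer somewhere upstream of these dataframes, roughly like this (a simplified sketch: raw_data, raw_label, f1, and f2 are placeholder names, not my exact pipeline):

from pyspark.ml.feature import StringIndexer, VectorAssembler

# Placeholder preprocessing: index the string label, assemble the features
indexer = StringIndexer(inputCol='raw_label', outputCol='label')
assembler = VectorAssembler(inputCols=['f1', 'f2'], outputCol='features')

indexed = indexer.fit(raw_data).transform(raw_data)
prepared = assembler.transform(indexed)
training_data, test_data = prepared.randomSplit([0.8, 0.2], seed=42)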
The problem happens when I try to filter the predictions and select a subset of them to convert to a pandas dataframe. The same exception occurs when I replace toPandas with count, collect, and so on. What surprises me the most is that when I remove that line and instead execute the RDD-based count above, everything works fine and the results come back. I've read several posts saying this is an issue with StringIndexer and that handleInvalid = 'keep' could fix it, but unfortunately I am running Spark 2.1, and I honestly don't think it has anything to do with StringIndexer anyway, since I'm able to fit, transform, and get predictions from the model. Is there anything I might be missing here?
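For what it's worth, the closest Spark 2.1 seems to offer is handleInvalid = 'skip' ('keep' only arrived in later versions, as far as I can tell), so the suggested fix would look something like this (an untested sketch, reusing the placeholder column name from above):

from pyspark.ml.feature import StringIndexer

# 'skip' drops rows whose label value was not seen during fit instead of
# raising "Unseen label"; 'keep' is not available in Spark 2.1
indexer = StringIndexer(inputCol='raw_label', outputCol='label',
                        handleInvalid='skip')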
And here is the full exception:
py4j.protocol.Py4JJavaError: An error occurred while calling o1040.collectToPython.
: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$4: (string) => double)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1072)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.eval(Expression.scala:409)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$canFilterOutNull(joins.scala:116)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$7.apply(joins.scala:125)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$7.apply(joins.scala:125)
at scala.collection.LinearSeqOptimized$class.exists(LinearSeqOptimized.scala:93)
at scala.collection.immutable.List.exists(List.scala:84)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$buildNewJoinType(joins.scala:125)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:140)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:138)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:288)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:288)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:287)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:277)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.apply(joins.scala:138)
at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.apply(joins.scala:105)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:85)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:82)
at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57)
at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66)
at scala.collection.mutable.WrappedArray.foldLeft(WrappedArray.scala:35)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:82)
at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:74)
at scala.collection.immutable.List.foreach(List.scala:381)
at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:74)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:73)
at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:73)
at org.apache.spark.sql.execution.QueryExecution$$anonfun$toString$2.apply(QueryExecution.scala:230)
at org.apache.spark.sql.execution.QueryExecution$$anonfun$toString$2.apply(QueryExecution.scala:230)
at org.apache.spark.sql.execution.QueryExecution.stringOrError(QueryExecution.scala:107)
at org.apache.spark.sql.execution.QueryExecution.toString(QueryExecution.scala:230)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:54)
at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2765)
at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:2742)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:280)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Unseen label: null.
at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$4.apply(StringIndexer.scala:170)
at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$4.apply(StringIndexer.scala:166)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:89)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:88)
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1069)
... 75 more
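Update: I do notice that the Caused by line says Unseen label: null, so in case a null in the column feeding the StringIndexer matters after all, this is the kind of check I plan to run (a sketch, reusing the placeholder names from above):

from pyspark.sql.functions import col

# A null in the column feeding the StringIndexer would explain
# "Unseen label: null" when its UDF is evaluated
print(raw_data.where(col('raw_label').isNull()).count())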