
I am using PySpark's random forest classifier and would like to create a pandas DataFrame from the predictions once I get them back. The strangest exception occurs when I try to do so. Here is my code:

from pyspark.ml.classification import RandomForestClassifier

random_forest = RandomForestClassifier(labelCol='label', featuresCol='features', maxDepth=4,
                                       impurity='entropy', numTrees=10, maxBins=250)
rf_model = random_forest.fit(training_data)
predictions = rf_model.transform(test_data)

# Where exception happens
df = predictions.select('rawPrediction', 'label', 'prediction').where((predictions.label == '1.0') & (predictions.prediction == '0.0')).toPandas()

# The code that works fine
label_pred_train = predictions.select('label', 'prediction')
print(label_pred_train.rdd.zipWithIndex().countByKey())

The problem happens when I try to filter the predictions and select a subset of them to convert to a pandas DataFrame. The same exception occurs when I replace `toPandas` with `count`, `collect`, etc. What surprises me the most is that when I remove that line and instead execute the lines above that count over the RDD, everything works fine and the results come back. I've read several posts saying this is an issue with `StringIndexer` and that I could use `handleInvalid='keep'`, but unfortunately I am running Spark 2.1, and honestly I don't think it has anything to do with `StringIndexer`, since I'm able to fit, transform, and get predictions from the model. Is there anything that I might be missing here?
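
For reference, `handleInvalid='keep'` (added to `StringIndexer` in Spark 2.2, as far as I can tell) routes unseen or invalid values into an extra index instead of throwing. A minimal sketch of both options, where `category` is a placeholder for whichever string column actually feeds the indexer (not a column from the question):

from pyspark.ml.feature import StringIndexer

# Spark >= 2.2: keep unseen/invalid values in an extra bucket
# ('category' is a placeholder column name, not taken from the question).
indexer = StringIndexer(inputCol='category', outputCol='category_idx',
                        handleInvalid='keep')

# Spark 2.1 workaround: drop rows with nulls in that column before indexing.
test_data = test_data.dropna(subset=['category'])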

And here is the full exception:

py4j.protocol.Py4JJavaError: An error occurred while calling o1040.collectToPython.
: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$4: (string) => double)
    at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1072)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.eval(Expression.scala:409)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$canFilterOutNull(joins.scala:116)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$7.apply(joins.scala:125)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$7.apply(joins.scala:125)
    at scala.collection.LinearSeqOptimized$class.exists(LinearSeqOptimized.scala:93)
    at scala.collection.immutable.List.exists(List.scala:84)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.org$apache$spark$sql$catalyst$optimizer$EliminateOuterJoin$$buildNewJoinType(joins.scala:125)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:140)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$$anonfun$apply$2.applyOrElse(joins.scala:138)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:288)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:288)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:70)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:287)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:331)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:188)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:329)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:293)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:277)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.apply(joins.scala:138)
    at org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin$.apply(joins.scala:105)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:85)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:82)
    at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57)
    at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66)
    at scala.collection.mutable.WrappedArray.foldLeft(WrappedArray.scala:35)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:82)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:74)
    at scala.collection.immutable.List.foreach(List.scala:381)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:74)
    at org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:73)
    at org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:73)
    at org.apache.spark.sql.execution.QueryExecution$$anonfun$toString$2.apply(QueryExecution.scala:230)
    at org.apache.spark.sql.execution.QueryExecution$$anonfun$toString$2.apply(QueryExecution.scala:230)
    at org.apache.spark.sql.execution.QueryExecution.stringOrError(QueryExecution.scala:107)
    at org.apache.spark.sql.execution.QueryExecution.toString(QueryExecution.scala:230)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:54)
    at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2765)
    at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:2742)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:280)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:214)
    at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Unseen label: null.
    at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$4.apply(StringIndexer.scala:170)
    at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$4.apply(StringIndexer.scala:166)
    at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:89)
    at org.apache.spark.sql.catalyst.expressions.ScalaUDF$$anonfun$2.apply(ScalaUDF.scala:88)
    at org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1069)
    ... 75 more
ahajib
  • *"do not think it has anything to do with `StringIndexer` since I'm able to do fit, transform and get predictions from the model"* - this is not true. [Spark is lazy](https://stackoverflow.com/questions/38027877/spark-transformation-why-its-lazy-and-what-is-the-advantage). – pault Jan 28 '19 at 21:26
  • @pault Well, I agree with you on that, but as I mentioned in my question, how come `countByKey` works fine while the collect does not? They are essentially going through all the samples, and I would assume that if there were any issue with the data, the same thing would happen when I use the RDD. – ahajib Jan 28 '19 at 21:33
  • 2
    I concur with @pault - the exception is thrown by the `StringIndexer`, and it is clearly indicated in the stack trace you've included. You can even find the exact line - [StringIndexer.scala:170](https://github.com/apache/spark/blob/4d2d3d47e00e78893b1ecd5a9a9070adc5243ac9/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala#L170) – 10465355 Jan 28 '19 at 22:51
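
To make the laziness point from the comments concrete, here is a toy sketch (a simplified stand-in with made-up data, not the asker's pipeline): a UDF that, like `StringIndexerModel`'s internal UDF in Spark 2.1, raises on a value it has not seen. Defining the transformation never fails; the error only surfaces once an action such as `collect`, `count`, or `toPandas` forces execution.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()

# Toy stand-in for StringIndexerModel's label-to-index UDF: it raises on any
# value (including null) that is not in its mapping.
def _index(s):
    mapping = {'a': 0.0, 'b': 1.0}
    if s not in mapping:
        raise ValueError('Unseen label: %s' % s)
    return mapping[s]

index_label = udf(_index, DoubleType())

df = spark.createDataFrame([('a',), (None,)], ['label'])
indexed = df.select(index_label('label'))  # lazy: no error raised here
indexed.collect()                          # the exception only surfaces here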

1 Answer


Use this:

df = predictions.select('rawPrediction', 'label', 'prediction').where((predictions.label == 1.0) & (predictions.prediction == 0.0)).toPandas()

instead of

df = predictions.select('rawPrediction', 'label', 'prediction').where((predictions.label == '1.0') & (predictions.prediction == '0.0')).toPandas()

Instead of using '1.0', use 1.0. Most probably Spark is looking for a string value in a column that is actually numeric; it cannot find one, so the comparison yields null, and that null is what produces the "Unseen label: null" error, I guess.
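
For what it's worth, a slightly more defensive version of that filter (a sketch, not part of the original answer) keeps the comparisons numeric and screens out nulls before they are compared at all:

from pyspark.sql import functions as F

# Filter nulls out explicitly, then compare against numeric literals.
df = (predictions
      .select('rawPrediction', 'label', 'prediction')
      .where(F.col('label').isNotNull() & F.col('prediction').isNotNull())
      .where((F.col('label') == 1.0) & (F.col('prediction') == 0.0))
      .toPandas())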