I have a PySpark dataframe df
(type(df)
is pyspark.sql.dataframe.DataFrame
) and it has 4 columns. I'm trying to find the number of rows it has using df.count()
, but I keep getting the error messages below.
WARN PythonRunner: Detected deadlock while completing task 24.0 in stage 4 (TID 28): Attempting to kill Python Worker
...
ERROR Executor: Exception in task 24.0 in stage 4.0 (TID 28)
...
ValueError: Shape of passed values is (4,1), indices imply (4,4)
I read that this ValueError usually means the data being passed has shape (4, 1) while the column labels imply a shape of (4, 4), but the examples I found for resolving it are for pandas DataFrames. I'm not sure how to resolve this for a PySpark DataFrame.
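For context, the same ValueError can be reproduced in plain pandas by passing a single column of values while supplying four column labels (a minimal sketch, not my actual data):

import numpy as np
import pandas as pd

# 4 values in one column, but 4 column labels ->
# ValueError: Shape of passed values is (4, 1), indices imply (4, 4)
pd.DataFrame(np.zeros((4, 1)), columns=['feature1', 'feature2', 'feature3', 'feature4'])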
Also, should I be concerned about the deadlock
warning? Is it related to the ValueError
?
ETA: Added the code that I have before calling df.count()
. Basically, I'm trying to calculate SHAP values for my model based on the code in this article.
from typing import Iterator

import numpy as np
import pandas as pd
import shap
from pyspark.sql.types import FloatType, StructField, StructType

explainer = shap.TreeExplainer(model)
shap_columns = ['feature1', 'feature2', 'feature3', 'feature4']

def calculate_shap(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # each X is one pandas DataFrame batch of rows from spark_X
    for X in iterator:
        yield pd.DataFrame(
            explainer.shap_values(np.array(X), check_additivity=False)[0],
            columns=shap_columns,
        )

# output schema: one FloatType column per feature
return_schema = StructType()
for feature in shap_columns:
    return_schema = return_schema.add(StructField(feature, FloatType()))

df = spark_X.mapInPandas(calculate_shap, schema=return_schema)
df.count()
Both df
and spark_X
are of type pyspark.sql.dataframe.DataFrame
. df.printSchema()
showed the 4 columns correctly.
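For reference, a quick way to see what explainer.shap_values returns for one batch is to run it locally on a small pandas sample (a sketch using random data with the same four feature columns, not my actual data):

import numpy as np
import pandas as pd

# small fake batch, only used to inspect the shape of the returned SHAP values
sample = pd.DataFrame(np.random.rand(5, 4), columns=shap_columns)
values = explainer.shap_values(np.array(sample), check_additivity=False)
print(type(values))      # list of arrays (one per output) or a single ndarray, depending on the model
print(np.shape(values))  # shape of whatever shap_values returned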
ETA2: Thanks to the suggestion by @samkart, I wrapped explainer.shap_values(np.array(X), check_additivity=False)[0] in an outer list, i.e. [explainer.shap_values(np.array(X), check_additivity=False)[0]], and there is no more error. But the number of rows returned is only 18K, while I'm expecting 180M. Why is this so?