
Here's my code, adapted from this GitHub repo, in which I'm trying to classify a set of MRI scan images as cancer / not cancer (0/1). As you can see in the code below, I get an error after defining the pipeline and starting to fit the model on the training dataset. I've also tried getting rid of the path column, but that doesn't fix it.

from functools import reduce
from pyspark.sql.functions import lit
import pandas as pd
from PIL import Image
import numpy as np
import io
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import VectorAssembler, VectorIndexer,StringIndexer
from pyspark.sql import SparkSession
from pyspark.ml.functions import vector_to_array



def init_spark():
  spark = SparkSession.builder\
    .appName("ML_image-app")\
    .getOrCreate()
  sc = spark.sparkContext
  return spark,sc


spark,sc = init_spark()

# read in the files from the mounted storage as binary file
no_brain = spark.read.format("binaryFile") \
  .option("pathGlobFilter", "*.jpg") \
  .option("recursiveFileLookup", "true") \
  .load("/FileStore/tables/no")\
  .withColumn("label",lit(0))
  #.load('dbfs:/mnt/de/path_to_images')
yes_brain = spark.read.format("binaryFile") \
  .option("pathGlobFilter", "*.jpg") \
  .option("recursiveFileLookup", "true") \
  .load("/FileStore/tables/yes")\
  .withColumn("label",lit(1))
  #.load('dbfs:/mnt/de/path_to_images')

dfs = [no_brain,yes_brain]


images_df = reduce(lambda no_brain, yes_brain: no_brain.union(yes_brain.select(no_brain.columns)), dfs)
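
# For reference, images_df.printSchema() at this point shows roughly the columns
# produced by the binaryFile reader plus the label column I added:
#  |-- path: string
#  |-- modificationTime: timestamp
#  |-- length: long
#  |-- content: binary
#  |-- label: integer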


# select the base model, here I have used ResNet50
model = ResNet50(include_top=False)
#model.summary()  # verify that the top layer is removed

bc_model_weights = sc.broadcast(model.get_weights())

#declaring functions to execute on the worker nodes of the Spark cluster
def model_fn():
  """
  Returns a ResNet50 model with top layer removed and broadcasted pretrained weights.
  """
  model = ResNet50(weights=None, include_top=False)
  model.set_weights(bc_model_weights.value)
  return model

def preprocess(content):
  """
  Preprocesses raw image bytes for prediction.
  """
  img = Image.open(io.BytesIO(content)).resize([224, 224])
  arr = img_to_array(img)
  return preprocess_input(arr)
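
# (Note: Image.open keeps the file's original mode, so a grayscale .jpg comes back
# as a single-channel 'L' image and img_to_array then gives shape (224, 224, 1).
# I suspect some of my MRI scans are grayscale, which may be what triggers the
# "index 1 is out of bounds for axis 2 with size 1" error from preprocess_input below.)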

def featurize_series(model, content_series):
  """
  Featurize a pd.Series of raw images using the input model.
  :return: a pd.Series of image features
  """
  input = np.stack(content_series.map(preprocess))
  preds = model.predict(input)
  # For some layers, output features will be multi-dimensional tensors.
  # We flatten the feature tensors to vectors for easier storage in Spark DataFrames.
  output = [p.flatten() for p in preds]
  return pd.Series(output)
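
# With include_top=False and 224x224 inputs, each prediction should be a 7x7x2048
# feature map, so every flattened feature vector has 100352 floats (my understanding
# from the model summary).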


@pandas_udf('array<float>', PandasUDFType.SCALAR_ITER)
def featurize_udf(content_series_iter):
  '''
  This method is a Scalar Iterator pandas UDF wrapping our featurization function.
  The decorator specifies that this returns a Spark DataFrame column of type ArrayType(FloatType).
  
  :param content_series_iter: This argument is an iterator over batches of data, where each batch
                              is a pandas Series of image data.
  '''
  # With Scalar Iterator pandas UDFs, we can load the model once and then re-use it
  # for multiple data batches.  This amortizes the overhead of loading big models.
  model = model_fn()
  for content_series in content_series_iter:
    yield featurize_series(model, content_series)


# Pandas UDFs on large records (e.g., very large images) can run into Out Of Memory (OOM) errors.
# If you hit such errors in the cell below, try reducing the Arrow batch size via `maxRecordsPerBatch`.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")


# We can now run featurization on our entire Spark DataFrame.
# NOTE: This can take a long time (about 10 minutes) since it applies a large model to the full dataset.
features_df = images_df.repartition(16).select(col("path"), featurize_udf("content").alias("features"))


#MLLib needs some post processing of the features column format
list_to_vector_udf = udf(lambda l: Vectors.dense(l), VectorUDT())
features_df = features_df.select(
    col("path"),
    list_to_vector_udf(features_df["features"]).alias("features")
)

#features_df = features_df.withColumn("features", vector_to_array("features"))

#dfte.csv("dbfs:/FileStore/tables/features_df")

f_y = features_df.filter(col("path").contains("Y")).withColumn("label",lit(0))
f_y=f_y.withColumn("label",lit(0))

f_n = features_df.filter(col("path").contains("N")).withColumn("label",lit(1))
f_n=f_n.withColumn("label",lit(0))

dfs = [f_y,f_n]


features_df = reduce(lambda f_n, f_y: f_n.union(f_y.select(f_n.columns)), dfs)
features_df = features_df.drop("path")


# splitting into training, validation and test sets
df_train_split, df_validate_split, df_test_split =  features_df.randomSplit([0.6, 0.3, 0.1],42)  


#Here we start to train the tail of the model

# This concatenates all feature columns into a single feature vector in a new column "featuresModel".
vectorAssembler = VectorAssembler(inputCols=['features'], outputCol="featuresModel")

labelIndexer = StringIndexer(inputCol="label", outputCol="indexedTarget").fit(features_df)

lr = LogisticRegression(maxIter=5, regParam=0.03, 
                        elasticNetParam=0.5, labelCol="indexedTarget", featuresCol="featuresModel")

# define a pipeline model
sparkdn = Pipeline(stages=[labelIndexer,vectorAssembler,lr])
spark_model = sparkdn.fit(df_train_split) # start fitting or training

# evaluating the model
predictions = spark_model.transform(df_test_split)

# Select example rows to display.
predictions.select("prediction", "indexedTarget", "features").show(5)

# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedTarget", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g" % (1.0 - accuracy))

I'm running it on a Databricks Community Edition Spark cluster with TensorFlow (it doesn't work on my local Spark cluster either). Here's the pip list:

Package                           Version
--------------------------------- --------------------
absl-py                           1.4.0
argon2-cffi                       20.1.0
astunparse                        1.6.3
async-generator                   1.10
attrs                             21.2.0
backcall                          0.2.0
backports.entry-points-selectable 1.1.1
black                             22.3.0
bleach                            4.0.0
boto3                             1.21.18
botocore                          1.24.18
cachetools                        5.2.1
certifi                           2021.10.8
cffi                              1.14.6
chardet                           4.0.0
charset-normalizer                2.0.4
click                             8.0.3
cryptography                      3.4.8
cycler                            0.10.0
Cython                            0.29.24
dbus-python                       1.2.16
debugpy                           1.4.1
decorator                         5.1.0
defusedxml                        0.7.1
distlib                           0.3.6
distro                            1.4.0
distro-info                       0.23ubuntu1
entrypoints                       0.3
facets-overview                   1.0.0
filelock                          3.8.0
flatbuffers                       23.1.4
gast                              0.4.0
google-auth                       2.16.0
google-auth-oauthlib              0.4.6
google-pasta                      0.2.0
grpcio                            1.51.1
h5py                              3.7.0
idna                              3.2
importlib-metadata                6.0.0
ipykernel                         6.12.1
ipython                           7.32.0
ipython-genutils                  0.2.0
ipywidgets                        7.7.0
jedi                              0.18.0
Jinja2                            2.11.3
jmespath                          0.10.0
joblib                            1.0.1
jsonschema                        3.2.0
jupyter-client                    6.1.12
jupyter-core                      4.8.1
jupyterlab-pygments               0.1.2
jupyterlab-widgets                1.0.0
keras                             2.11.0
kiwisolver                        1.3.1
libclang                          15.0.6.1
Markdown                          3.4.1
MarkupSafe                        2.1.2
matplotlib                        3.4.3
matplotlib-inline                 0.1.2
mistune                           0.8.4
mypy-extensions                   0.4.3
nbclient                          0.5.3
nbconvert                         6.1.0
nbformat                          5.1.3
nest-asyncio                      1.5.1
notebook                          6.4.5
numpy                             1.20.3
oauthlib                          3.2.2
opt-einsum                        3.3.0
packaging                         21.0
pandas                            1.3.4
pandocfilters                     1.4.3
parso                             0.8.2
pathspec                          0.9.0
patsy                             0.5.2
pexpect                           4.8.0
pickleshare                       0.7.5
Pillow                            8.4.0
pip                               21.2.4
platformdirs                      2.5.2
plotly                            5.9.0
prometheus-client                 0.11.0
prompt-toolkit                    3.0.20
protobuf                          3.19.6
psutil                            5.8.0
psycopg2                          2.9.3
ptyprocess                        0.7.0
pyarrow                           7.0.0
pyasn1                            0.4.8
pyasn1-modules                    0.2.8
pycparser                         2.20
Pygments                          2.10.0
PyGObject                         3.36.0
pyodbc                            4.0.31
pyparsing                         3.0.4
pyrsistent                        0.18.0
python-apt                        2.0.0+ubuntu0.20.4.8
python-dateutil                   2.8.2
pytz                              2021.3
pyzmq                             22.2.1
requests                          2.26.0
requests-oauthlib                 1.3.1
requests-unixsocket               0.2.0
rsa                               4.9
s3transfer                        0.5.2
scikit-learn                      0.24.2
scipy                             1.7.1
seaborn                           0.11.2
Send2Trash                        1.8.0
setuptools                        58.0.4
six                               1.16.0
ssh-import-id                     5.10
statsmodels                       0.12.2
tenacity                          8.0.1
tensorboard                       2.11.2
tensorboard-data-server           0.6.1
tensorboard-plugin-wit            1.8.1
tensorflow                        2.11.0
tensorflow-estimator              2.11.0
tensorflow-io-gcs-filesystem      0.29.0
termcolor                         2.2.0
terminado                         0.9.4
testpath                          0.5.0
threadpoolctl                     2.2.0
tokenize-rt                       4.2.1
tomli                             2.0.1
tornado                           6.1
traitlets                         5.1.0
typing-extensions                 3.10.0.2
unattended-upgrades               0.1
urllib3                           1.26.7
virtualenv                        20.8.0
wcwidth                           0.2.5
webencodings                      0.5.1
Werkzeug                          2.2.2
wheel                             0.37.0
widgetsnbextension                3.6.0
wrapt                             1.14.1
zipp                              3.11.0

But I'm getting the following error:


PythonException: An exception was thrown from a UDF: 'IndexError: index 1 is out of bounds for axis 2 with size 1', from , line 75. Full traceback below:
Traceback (most recent call last):
  File "", line 103, in featurize_udf
  File "", line 82, in featurize_series
  File "/databricks/python/lib/python3.9/site-packages/pandas/core/series.py", line 4161, in map
    new_values = super()._map_values(arg, na_action=na_action)
  File "/databricks/python/lib/python3.9/site-packages/pandas/core/base.py", line 870, in _map_values
    new_values = map_f(values, mapper)
  File "pandas/_libs/lib.pyx", line 2859, in pandas._libs.lib.map_infer
  File "", line 75, in preprocess
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/keras/applications/resnet.py", line 611, in preprocess_input
    return imagenet_utils.preprocess_input(
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/keras/applications/imagenet_utils.py", line 121, in preprocess_input
    return _preprocess_numpy_input(x, data_format=data_format, mode=mode)
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/keras/applications/imagenet_utils.py", line 241, in _preprocess_numpy_input
    x[..., 1] -= mean[1]
IndexError: index 1 is out of bounds for axis 2 with size 1
---------------------------------------------------------------------------
PythonException                           Traceback (most recent call last)
<command-2425500801898995> in <cell line: 161>()
    159 # define a pipeline model
    160 sparkdn = Pipeline(stages=[labelIndexer,vectorAssembler,lr])
--> 161 spark_model = sparkdn.fit(df_train_split) # start fitting or training

/databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py in patched_method(self, *args, **kwargs)
     28             call_succeeded = False
     29             try:
---> 30                 result = original_method(self, *args, **kwargs)
     31                 call_succeeded = True
     32                 return result

/databricks/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
    203                 return self.copy(params)._fit(dataset)
    204             else:
--> 205                 return self._fit(dataset)
    206         else:
    207             raise TypeError(

/databricks/spark/python/pyspark/ml/pipeline.py in _fit(self, dataset)
    130                 if isinstance(stage, Transformer):
    131                     transformers.append(stage)
--> 132                     dataset = stage.transform(dataset)
    133                 else:  # must be an Estimator
    134                     model = stage.fit(dataset)

/databricks/spark/python/pyspark/ml/base.py in transform(self, dataset, params)
    260                 return self.copy(params)._transform(dataset)
    261             else:
--> 262                 return self._transform(dataset)
    263         else:
    264             raise TypeError("Params must be a param map but got %s." % type(params))

/databricks/spark/python/pyspark/ml/wrapper.py in _transform(self, dataset)
    398 
    399         self._transfer_params_to_java()
--> 400         return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sparkSession)
    401 
    402 

/databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1319 
   1320         answer = self.gateway_client.send_command(command)
-> 1321         return_value = get_return_value(
   1322             answer, self.gateway_client, self.target_id, self.name)
   1323 

/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
    200                 # Hide where the exception came from that shows a non-Pythonic
    201                 # JVM exception message.
--> 202                 raise converted from None
    203             else:
    204                 raise
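
To narrow it down, here is a minimal standalone version of the preprocessing step, run outside Spark on a single file (the path is just a placeholder for one of my .jpg scans); my assumption is that a grayscale image reproduces the same IndexError:

import io
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.resnet50 import preprocess_input

# placeholder path for one of the MRI .jpg files, copied locally
with open("/tmp/sample_scan.jpg", "rb") as f:
    content = f.read()

img = Image.open(io.BytesIO(content)).resize([224, 224])
print(img.mode)        # 'L' for a grayscale jpg, 'RGB' for a colour one
arr = img_to_array(img)
print(arr.shape)       # (224, 224, 1) would match the "axis 2 with size 1" error
preprocess_input(arr)  # raises the same IndexError when arr has a single channel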

I need help; any suggestion would be greatly appreciated.
