I have trouble using a trained Keras model in PySpark. The following versions of libraries are used:

tensorflow==1.1.0
h5py==2.7.0
keras==2.0.4

Also, I use Spark 2.4.0.

import numpy as np
from pyspark.sql import SparkSession
import pyspark.sql.functions as func
from pyspark.sql.types import IntegerType
from keras.models import load_model

spark = SparkSession \
    .builder \
    .appName("Test") \
    .master("local[2]") \
    .getOrCreate()

my_model = load_model("my_model.h5")
spark.sparkContext.addFile("my_model.h5")
my_model_bcast = spark.sparkContext.broadcast(my_model)

# ...

get_prediction_udf = func.udf(get_prediction, IntegerType())
ds = ds \
    .withColumn("predicted_value", get_prediction_udf(my_model_bcast,
                                                      func.col("col1"),
                                                      func.col("col2")))

The function get_prediction looks as follows (simplified):

def get_prediction(my_model_bcast, col1, col2):
    cur_state = np.array([col1,col2])
    state = cur_state.reshape(1,2)
    ynew = my_model_bcast.predict(state)
    return np.argmax(ynew[0])

The following error is triggered by the line my_model_bcast = spark.sparkContext.broadcast(my_model):

  File "/usr/local/spark-2.4.0-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/broadcast.py", line 110, in dump
    pickle.dump(value, f, 2)
TypeError: can't pickle _thread.lock objects

I have read similar threads looking for a solution. As far as I understand, Keras models do not support pickling. In that case, how can I make predictions in PySpark using a trained model?
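For context, the same kind of failure can be reproduced without Keras or Spark. The sketch below assumes the model keeps a thread lock somewhere in its object graph (plausibly through its TensorFlow session, which is not verified here); `is_picklable` is a hypothetical helper name used only for illustration.

```python
import pickle
import threading


def is_picklable(obj):
    """Return True if obj can be serialized with pickle, False otherwise."""
    try:
        # Protocol 2 mirrors the pickle.dump(value, f, 2) call in the traceback.
        pickle.dumps(obj, protocol=2)
        return True
    except (TypeError, pickle.PicklingError):
        return False


# Plain data pickles without trouble.
print(is_picklable({"weights": [0.1, 0.2]}))

# Any container that holds a lock fails, just like the broadcast call.
print(is_picklable({"lock": threading.Lock()}))
```

Because `broadcast` pickles its argument on the driver, any object that fails this check cannot be broadcast as-is.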

ScalaBoy

1 Answer

It doesn't seem possible to pickle Keras models, so instead of broadcasting the model, you could distribute the model file as a Spark file. Then, inside the function that currently expects the model as input, read the file from its local path and load the model there:

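from pyspark import SparkFiles
from keras.models import load_model
path = SparkFiles.get("my_model.h5")  # same file name that was passed to addFile
model = load_model(path)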
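A slightly fuller sketch of this idea, assuming the file was registered on the driver with `spark.sparkContext.addFile("my_model.h5")` as in the question. The names `_MODEL_CACHE`, `get_cached_model`, and the `loader` parameter are hypothetical helpers introduced here, not part of Keras or PySpark. The cache is there so the model is not reloaded for every row.

```python
# Cache of loaded models, keyed by file name. It lives in the Python worker
# process, so nothing unpicklable ever has to leave the driver.
_MODEL_CACHE = {}


def get_cached_model(file_name, loader=None):
    """Load the model on first use and reuse it afterwards.

    `loader` is a hypothetical hook for swapping in another loading
    function; by default the file is resolved through SparkFiles and
    loaded with Keras.
    """
    if file_name not in _MODEL_CACHE:
        if loader is None:
            # Imported lazily so the imports resolve where the UDF runs.
            from pyspark import SparkFiles
            from keras.models import load_model
            _MODEL_CACHE[file_name] = load_model(SparkFiles.get(file_name))
        else:
            _MODEL_CACHE[file_name] = loader(file_name)
    return _MODEL_CACHE[file_name]


def get_prediction(col1, col2, model_file="my_model.h5", loader=None):
    """Predict the class index for one row; takes column values, not the model."""
    import numpy as np

    model = get_cached_model(model_file, loader)
    state = np.array([col1, col2]).reshape(1, 2)
    ynew = model.predict(state)
    # Cast to a plain Python int: a UDF declared with IntegerType should not
    # return a NumPy scalar.
    return int(np.argmax(ynew[0]))
```

With this shape the UDF would be registered as `func.udf(get_prediction, IntegerType())` and called with only `func.col("col1")` and `func.col("col2")`, so no model object is passed as a column. As far as I know, returning a NumPy scalar from a Python UDF is a common cause of `net.razorvine.pickle` errors, which is why the result is wrapped in `int(...)`; whether that is the cause in any particular job is an assumption.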
0xc0de
  • Thank you. I did exactly what you suggested. I got the error `Caused by: net.razorvine.pickle.objects.ClassDictConstructor.construct`. – ScalaBoy Dec 10 '18 at 14:54