I am trying to log a Tensorflow.keras model using mlflow. In my workflow I need to define a custom .predict() . So based on this I have done these steps:
First I trained the tf.keras model as following:
_input = (
tf.data.Dataset.range(10)
.repeat()
.batch(1)
.batch(WINDOW_SIZE)
.batch(MINIBATCH_SIZE)
.map(lambda x: (x, x))
)
tf_model = Seq2Seq(WINDOW_SIZE=WINDOW_SIZE).get_model()
tf_model.compile(loss="mse", optimizer="adam")
tf_model.fit(_input, epochs=3, steps_per_epoch=3)
Once the model is trained, I have lodaded it using mlflow.pyfunc:
class NNModelWrapper(mlflow.pyfunc.PythonModel):
def __init__(
self, model, input_name, target_variance, ffill_limit, sample_rate, window_size
):
self.model = model
self.INPUT_NAME = input_name
self.TRANSFORMED_INPUT_NAME = f"zscore_{self.INPUT_NAME}"
self.TARGET_VARIANCE = target_variance
self.FFILL_LIMIT = ffill_limit
self.SAMPLE_RATE = sample_rate
self.WINDOW_SIZE = window_size
def predict(self, context, input_df):
df = input_df #apply preprocessing functions...
input_data = df[self.INPUT_NAME].values.flatten()
transformed_input_data = df[self.TRANSFORMED_INPUT_NAME].values.reshape(
-1, self.WINDOW_SIZE, 1
)
result = self.model.predict(transformed_input_data)
result = result.flatten()
result = result * np.sqrt(self.TARGET_VARIANCE)
result = np.where(result < 0, 0, result)
result = np.where(result > input_data, input_data, result)
return pd.DataFrame({"time": df.index, "inference": result}).set_index("time")
py_model = mlflow.pyfunc.load_model(mlflow.get_artifact_uri() + "/model")
wrapped_model = NNModelWrapper(
model=py_model,
input_name="input_P",
target_variance=10,
ffill_limit=180,
sample_rate="1T",
window_size=WINDOW_SIZE,
)
With this model wrapper I am able to get something when calling .predict() method. Everything fine here.
df = get_fake_timeseries(start="2021-01-01", end="2021-01-02")
result = wrapped_model.predict(context=None, input_df=df)
But the problem comes when I try to log the model using mlflow library again. This line of code gives the following output.
signature = infer_signature(df, result)
mlflow.pyfunc.log_model(
"wrapped_model", python_model=wrapped_model, signature=signature
)
...
File "/usr/lib/python3.6/pickle.py", line 821, in save_dict
self._batch_setitems(obj.items())
File "/usr/lib/python3.6/pickle.py", line 847, in _batch_setitems
save(v)
File "/usr/lib/python3.6/pickle.py", line 496, in save
rv = reduce(self.proto)
TypeError: can't pickle _thread.RLock objects
Any ideas on what could be doing wrong?