
I've seen this question and this one, but neither explains what is going on, nor offers a solution to the problem I'm facing.

The code below is a snippet from what I'm trying to do in a larger context. Basically, I'm creating an object that contains a tensorflow.keras model and saving it to a file with pickle, using a trick adapted from this answer. The actual class I'm working on has several other fields and methods, which is why I'd prefer to make it picklable, and to do so in a flexible manner. This minimal script reproduces the problem, ReproduceProblem.py:

import pickle
import numpy as np
import tempfile
import tensorflow as tf


def __getstate__(self):
    model_str = ""
    with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
        tf.keras.models.save_model(self, fd.name, overwrite=True)
        model_str = fd.read()
    d = {"model_str": model_str}
    return d


def __setstate__(self, state):
    with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
        fd.write(state["model_str"])
        fd.flush()
        model = tf.keras.models.load_model(fd.name)
    self.__dict__ = model.__dict__


class ContainsSequential:
    def __init__(self):
        self.other_field = "potato"
        self.model = tf.keras.models.Sequential()
        self.model.__getstate__ = lambda mdl=self.model: __getstate__(mdl)
        self.model.__setstate__ = __setstate__
        self.model.add(tf.keras.layers.Input(shape=(None, 3)))
        self.model.add(tf.keras.layers.LSTM(3, activation="relu", return_sequences=True))
        self.model.add(tf.keras.layers.Dense(3, activation="linear"))


# Now do the business:
tf.keras.backend.clear_session()
file_name = 'pickle_file.pckl'
instance = ContainsSequential()
instance.model.predict(np.random.rand(3, 1, 3))
print(instance.other_field)
with open(file_name, 'wb') as fid:
    pickle.dump(instance, fid)
with open(file_name, 'rb') as fid:
    restored_instance = pickle.load(fid)
print(restored_instance.other_field)
restored_instance.model.predict(np.random.rand(3, 1, 3))
print('Done')

While it does not fail on the line instance.model.predict(np.random.rand(3, 1, 3)), it does fail on the line restored_instance.model.predict(np.random.rand(3, 1, 3)) with the following error message:

  File "<path>\ReproduceProblem.py", line 52, in <module>
    restored_instance.model.predict(np.random.rand(3, 1, 3))
  File "<path>\Python\Python39\lib\site-packages\keras\engine\training.py", line 1693, in predict
    if self.distribute_strategy._should_use_with_coordinator:  # pylint: disable=protected-access
  File "<path>\Python\Python39\lib\site-packages\keras\engine\training.py", line 716, in distribute_strategy
    return self._distribution_strategy or tf.distribute.get_strategy()
AttributeError: 'Sequential' object has no attribute '_distribution_strategy'

I don't have the slightest idea what _distribution_strategy is supposed to be, but in my workflow, once I've saved the file I don't need to train the model anymore; I only use it to make predictions or to consult other attributes of the class. I've tried setting it to None and adding more attributes, but with no success.

Mefitico

2 Answers


It is dangerous to redefine methods of a TensorFlow class on an instance like this:

self.model = tf.keras.models.Sequential()
self.model.__getstate__ = lambda mdl=self.model: __getstate__(mdl)
self.model.__setstate__ = __setstate__
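One likely reason this breaks can be reproduced without TensorFlow. pickle looks up `__getstate__` on the object being saved, so the hook attached to the model instance is used when dumping. When loading, however, pickle creates a fresh object with `cls.__new__` (skipping `__init__`) and looks for `__setstate__` on that fresh object, which does not carry the instance-level hook. The returned state dict is then copied into `__dict__` as-is, so attributes normally set in `__init__`, such as `_distribution_strategy`, are missing. A minimal sketch, with a hypothetical `Plain` class standing in for the Keras model:

```python
import pickle


class Plain:
    """Hypothetical plain-Python stand-in for a Keras model."""

    def __init__(self):
        # Mimics an attribute that Keras sets in __init__ and relies on later.
        self._distribution_strategy = None


obj = Plain()
# Attach the hooks to the *instance*, as the question does with self.model.
obj.__getstate__ = lambda: {"model_str": b"..."}
obj.__setstate__ = lambda state: print("never called")

data = pickle.dumps(obj)       # pickle finds the instance-level __getstate__
restored = pickle.loads(data)  # new object via Plain.__new__: no __init__ and
                               # no instance-level __setstate__, so the state
                               # dict is copied straight into __dict__
print(restored.__dict__)                            # {'model_str': b'...'}
print(hasattr(restored, "_distribution_strategy"))  # False
```

Defining the hooks on the class itself, as in the example below, avoids this because the freshly created object then finds `__setstate__` through its class.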

I'd recommend avoiding that and defining the __getstate__ and __setstate__ methods on the custom class instead. Here is a working example:

import pickle
import numpy as np
import tempfile
import tensorflow as tf


class ContainsSequential:
    def __init__(self):
        self.other_field = "Potato"
        self.model = tf.keras.models.Sequential()
        self.model.add(tf.keras.layers.Input(shape=(None, 3)))
        self.model.add(tf.keras.layers.LSTM(3, activation="relu", return_sequences=True))
        self.model.add(tf.keras.layers.Dense(3, activation="linear"))
        
    def __getstate__(self):
        # Save the Keras model to a temporary HDF5 file and keep its raw bytes.
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
            tf.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        # Every other attribute that must survive pickling is listed here too.
        d = {"model_str": model_str, "other_field": self.other_field}
        return d

    def __setstate__(self, state):
        # Write the bytes back to a temporary file and reload the model from it.
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=False) as fd:
            fd.write(state["model_str"])
            fd.flush()
            model = tf.keras.models.load_model(fd.name)
        self.model = model
        self.other_field = state["other_field"]

And a test:

tf.keras.backend.clear_session()
file_name = 'pickle_file.pkl'
instance = ContainsSequential()

rnd = np.random.rand(3, 1, 3)
print(1, instance.model.predict(rnd))
with open(file_name, 'wb') as fid:
    pickle.dump(instance, fid)
with open(file_name, 'rb') as fid:
    r_instance = pickle.load(fid)
print(2, r_instance.model.predict(rnd))
print(r_instance.other_field)
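A side note on the approach above: `NamedTemporaryFile(suffix=".hdf5", delete=False)` leaves a temporary file on disk after every dump and load. A hedged sketch of helpers that go through a `tempfile.TemporaryDirectory` instead, so the file is removed afterwards. `to_bytes_via_file` and `from_bytes_via_file` are hypothetical names; with Keras you would presumably pass `lambda p: tf.keras.models.save_model(self.model, p)` as `save_fn` and `tf.keras.models.load_model` as `load_fn`. Dummy save/load functions are used here so it runs without TensorFlow:

```python
import os
import tempfile


def to_bytes_via_file(save_fn, suffix=".hdf5"):
    """Call save_fn(path) on a temporary path and return the saved file's bytes."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "model" + suffix)
        save_fn(path)
        with open(path, "rb") as f:
            return f.read()  # tmp_dir is deleted when the outer block exits


def from_bytes_via_file(data, load_fn, suffix=".hdf5"):
    """Write data to a temporary path and return load_fn(path)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "model" + suffix)
        with open(path, "wb") as f:
            f.write(data)
        return load_fn(path)  # load_fn must finish reading before cleanup


# Hypothetical dummy save/load functions standing in for Keras.
saved_paths = []


def fake_save(path):
    saved_paths.append(path)
    with open(path, "wb") as f:
        f.write(b"model-bytes")


def fake_load(path):
    with open(path, "rb") as f:
        return f.read()


blob = to_bytes_via_file(fake_save)
restored = from_bytes_via_file(blob, fake_load)
print(blob == restored)                # True
print(os.path.exists(saved_paths[0]))  # False: the temporary directory is gone
```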
user1635327
  • Interesting answer. While it solves the bug in my minimal working example, it creates another problem (maybe the fault is on my part). I've added a field "other_field" to the class (as, of course, the class I'm actually working with has other methods and fields), then tried to print it. This causes an error. I'll edit my question to show it, and maybe your answer too, if you'd accept that. But TL;DR: it didn't solve my problem. – Mefitico Oct 02 '21 at 00:12
  • You'll need to add other fields to `__getstate__` and `__setstate__`. I've added this example to the answer. – user1635327 Oct 02 '21 at 06:30
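To avoid listing every field by hand, `__getstate__` can copy the whole `__dict__` and swap out only the model, which may be closer to the flexible approach asked for in the question. A minimal sketch under stated assumptions: `FakeModel`, `model_to_bytes` and `model_from_bytes` are hypothetical stand-ins for the Keras model and for the temp-file save/load code in the answer above, so the example runs without TensorFlow:

```python
import json
import pickle


class FakeModel:
    """Hypothetical stand-in for a Keras model that must not be pickled directly."""

    def __init__(self, weights):
        self.weights = weights

    def __reduce__(self):
        raise TypeError("FakeModel must go through model_to_bytes()")


def model_to_bytes(model):
    # Stand-in for: save the Keras model to a temp .hdf5 file and read the bytes.
    return json.dumps(model.weights).encode()


def model_from_bytes(data):
    # Stand-in for: write the bytes to a temp file and call load_model() on it.
    return FakeModel(json.loads(data.decode()))


class ContainsModel:
    def __init__(self):
        self.other_field = "potato"
        self.more_fields = {"a": 1, "b": [1, 2, 3]}
        self.model = FakeModel([0.1, 0.2, 0.3])

    def __getstate__(self):
        # Copy all attributes, then swap only the model for its serialized bytes.
        state = self.__dict__.copy()
        state["model"] = model_to_bytes(self.model)
        return state

    def __setstate__(self, state):
        # Rebuild the model from its bytes and restore everything else unchanged.
        state = dict(state)
        state["model"] = model_from_bytes(state["model"])
        self.__dict__.update(state)


instance = ContainsModel()
restored = pickle.loads(pickle.dumps(instance))
print(restored.other_field, restored.more_fields)  # potato {'a': 1, 'b': [1, 2, 3]}
print(restored.model.weights)                      # [0.1, 0.2, 0.3]
```

New fields added in `__init__` are then pickled automatically; with a real model, any other attribute holding Keras objects would likely need the same swap as `model`.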

Instead of using pickle to serialize/deserialize TensorFlow models, you should use model.save('path/to/location') and keras.models.load_model(). This is the recommended practice; see the documentation at https://www.tensorflow.org/guide/keras/save_and_serialize.

  • While I expect this answer to be helpful for people who stumble upon this question in the future, one of the issues I'm dealing with is that the object I'm trying to save has a lot more to it than just a Keras model, and all of that other stuff should (preferably) be pickled together with it. The trick in the question used to work in past versions of Keras/TensorFlow, hence there is a not-so-small code base that assumes something similar should still work. – Mefitico Oct 04 '21 at 20:20