
I'm trying to do something fairly simple in TensorFlow Lite, but I'm not sure it's possible.

I want to define a stateful graph where the shape of the state variable is defined when the model is LOADED, not when it's saved.

As a simple example, let's say I just want to compute a temporal difference, i.e. a graph that returns the difference between the inputs of two consecutive calls (the state starts at zero, so the first call returns its input unchanged). The following should pass:

import numpy as np

# load_tflite_model_func: helper defined elsewhere (not a TensorFlow API)
func = load_tflite_model_func(tflite_model_file_path)
runtime_shape = 60, 80
rng = np.random.RandomState(1234)
ims = [rng.randn(*runtime_shape).astype(np.float32) for _ in range(3)]
assert np.allclose(func(ims[0]), ims[0])
assert np.allclose(func(ims[1]), ims[1]-ims[0])
assert np.allclose(func(ims[2]), ims[2]-ims[1])
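For reference, the behaviour this test expects can be written in plain NumPy. This is only an illustrative sketch, and the class name `NumpyTimeDelta` is hypothetical; the point is that the state is created lazily with whatever shape the first input has, which is exactly what is hard to bake into a saved TFLite graph.

```python
import numpy as np


class NumpyTimeDelta:
    """Plain-NumPy reference for the expected behaviour (illustrative only)."""

    def __init__(self) -> None:
        # No state yet: its shape is only known once the first input arrives.
        self._last_val = None

    def __call__(self, arr: np.ndarray) -> np.ndarray:
        if self._last_val is None:
            # Lazily create zero state with the run-time shape of the first input.
            self._last_val = np.zeros_like(arr)
        delta = arr - self._last_val
        # Remember a copy so later in-place changes to arr don't affect the state.
        self._last_val = arr.copy()
        return delta


func = NumpyTimeDelta()
runtime_shape = 60, 80
rng = np.random.RandomState(1234)
ims = [rng.randn(*runtime_shape).astype(np.float32) for _ in range(3)]
assert np.allclose(func(ims[0]), ims[0])
assert np.allclose(func(ims[1]), ims[1] - ims[0])
assert np.allclose(func(ims[2]), ims[2] - ims[1])
```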

Now, to create and save the model, I do:

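import tempfile
from dataclasses import dataclass
from typing import Optional

import tensorflow as tf

@dataclass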
class TimeDelta(tf.Module):
    _last_val: Optional[tf.Tensor] = None
    def compute_delta(self, arr: tf.Tensor):
        if self._last_val is None:
            self._last_val = tf.Variable(tf.zeros(tf.shape(arr)))
        delta = arr-self._last_val
        self._last_val.assign(arr)
        return delta

compile_time_shape = 30, 40
# compile_time_shape = None, None  # Causes UnliftableError
tflite_model_file_path = tempfile.mktemp()
delta = TimeDelta()
# save_signatures_to_tflite_model: helper defined elsewhere (not a TensorFlow API)
save_signatures_to_tflite_model(
    {'delta': tf.function(delta.compute_delta, input_signature=[tf.TensorSpec(shape=compile_time_shape)])},
    path=tflite_model_file_path,
    parent_object=delta
)
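`save_signatures_to_tflite_model` and `load_tflite_model_func` are helpers whose definitions aren't shown here; they are not TensorFlow APIs. As a rough sketch only (not the original code), they could be built on a SavedModel export, `tf.lite.TFLiteConverter.from_saved_model` and a TFLite signature runner, assuming a TF 2.x release whose converter supports resource variables. The `signature_key="delta"` default is an assumption chosen to match the example above.

```python
import tempfile
from typing import Any, Callable, Mapping

import numpy as np
import tensorflow as tf


def save_signatures_to_tflite_model(
    signatures: Mapping[str, Any],  # name -> tf.function with an input_signature
    path: str,
    parent_object: tf.Module,
) -> None:
    """Sketch: export stateful tf.functions to a .tflite file via a SavedModel."""
    concrete = {name: fn.get_concrete_function() for name, fn in signatures.items()}
    saved_model_dir = tempfile.mkdtemp()
    # Saving parent_object makes the tf.Variables it tracks part of the SavedModel.
    tf.saved_model.save(parent_object, saved_model_dir, signatures=concrete)
    converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir, signature_keys=list(concrete))
    # Allow the converted model to read and assign variables at run time.
    converter.experimental_enable_resource_variables = True
    with open(path, "wb") as f:
        f.write(converter.convert())


def load_tflite_model_func(
    path: str, signature_key: str = "delta"
) -> Callable[[np.ndarray], np.ndarray]:
    """Sketch: wrap one single-input, single-output signature as a function."""
    interpreter = tf.lite.Interpreter(model_path=path)
    io_names = interpreter.get_signature_list()[signature_key]
    (input_name,), (output_name,) = io_names["inputs"], io_names["outputs"]
    runner = interpreter.get_signature_runner(signature_key)

    def func(x: np.ndarray) -> np.ndarray:
        # The runner resizes the input tensor to x.shape before invoking.
        return runner(**{input_name: x})[output_name]

    return func
```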

The problem, of course, is that if my compile-time shape differs from my run-time shape, the model crashes. Trying to make the graph dynamically shaped with compile_time_shape = None, None also fails: saving the graph raises an UnliftableError, because the variable needs concrete dimensions.

A full Colab notebook demonstrating the problem is here.

So, to summarise, the question is:

How can I save a stateful graph in TFLite, where the shape of the graph's state depends on the shape of the input?

Corresponding TensorFlow issue: https://github.com/tensorflow/tensorflow/issues/59217

Peter

1 Answer


OK, I found a solution that isn't ideal but does the job: make the variable the largest size it could conceivably need at run time, and only read and assign a slice of it. Here is a modified notebook that does that.
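The modified notebook isn't reproduced here, so the following is only a minimal sketch of the same idea, assuming TF 2.x with TFLite resource-variable support. `PaddedTimeDelta`, `MAX_SHAPE` and `to_tflite` are hypothetical names, and the 128 x 128 bound is arbitrary. Rather than assigning into a slice of the variable, this sketch swaps in a slightly different write: it zero-pads the new input to the full buffer size with `tf.pad` and assigns the whole buffer, while reading back only the top-left `[:h, :w]` window.

```python
import tempfile

import tensorflow as tf

# Hypothetical upper bound on the run-time input shape; inputs must not exceed it.
MAX_SHAPE = (128, 128)


class PaddedTimeDelta(tf.Module):
    """Temporal difference whose state lives in a fixed, maximum-size buffer."""

    def __init__(self, max_shape=MAX_SHAPE):
        super().__init__()
        self.max_h, self.max_w = max_shape
        # Allocated once at the largest shape, so saving never needs the input shape.
        self.buf = tf.Variable(tf.zeros(max_shape, dtype=tf.float32))

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.float32)])
    def compute_delta(self, arr):
        h, w = tf.shape(arr)[0], tf.shape(arr)[1]
        # Read only the top-left (h, w) window that matches the current input.
        delta = arr - self.buf[:h, :w]
        # Zero-pad the input up to the buffer size and overwrite the whole buffer.
        padded = tf.pad(arr, [[0, self.max_h - h], [0, self.max_w - w]])
        self.buf.assign(padded)
        return delta


def to_tflite(module: PaddedTimeDelta) -> bytes:
    """Export via a SavedModel (so the variable is tracked), then convert."""
    saved_model_dir = tempfile.mkdtemp()
    tf.saved_model.save(
        module, saved_model_dir,
        signatures={"delta": module.compute_delta.get_concrete_function()})
    converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir, signature_keys=["delta"])
    # Allow the converted model to read and assign variables at run time.
    converter.experimental_enable_resource_variables = True
    return converter.convert()
```

To use it, load the bytes with `tf.lite.Interpreter(model_content=to_tflite(PaddedTimeDelta()))` and call `interpreter.get_signature_runner("delta")(arr=x)["output_0"]`. The buffer's float32 zeros cost 4 bytes per element, so the 128 x 128 bound here is 128 x 128 x 4 = 65,536 bytes, and the size grows with the product of the maximum dimensions. If consecutive inputs change shape, the window read back is a crop or zero-padded version of the previous input, so the delta only makes sense between same-shaped consecutive calls.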

A downside of this approach is that you end up with very large .tflite files (in my case, 24 MB of zeros), because the variable's full-size initial value is stored in the model.

Peter