
I have a Keras (sequential) model that could be saved with custom signature defs in TensorFlow 1.13 as follows:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.saved_model.utils import build_tensor_info
from tensorflow.saved_model.signature_def_utils import predict_signature_def, build_signature_def

model = Sequential()  # with some layers

builder = tf.saved_model.builder.SavedModelBuilder(export_path)

score_signature = predict_signature_def(
    inputs={'waveform': model.input},
    outputs={'scores': model.output})

metadata = build_signature_def(
    outputs={'other_variable': build_tensor_info(tf.constant(1234, dtype=tf.int64))})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  builder.add_meta_graph_and_variables(
      sess=sess,
      tags=[tf.saved_model.tag_constants.SERVING],
      signature_def_map={'score': score_signature, 'metadata': metadata})
  builder.save()

Migrating the model to TF2 Keras was cool :), but I can't figure out how to save the model with the same signatures as above. Should I be using the new tf.saved_model.save() or tf.keras.experimental.export_saved_model()? How should the above code be written in TF2?

Key requirements:

  • The model has a score signature and a metadata signature
  • The metadata signature contains 1 or more constants

1 Answer


The solution is to create a tf.Module with functions for each signature definition:

import tensorflow as tf

class MyModule(tf.Module):
  def __init__(self, model, other_variable):
    self.model = model
    self._other_variable = other_variable

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
  def score(self, waveform):
    result = self.model(waveform)
    return { "scores": result }

  @tf.function(input_signature=[])
  def metadata(self):
    return { "other_variable": self._other_variable }

And then save the module (not the model):

module = MyModule(model, tf.constant(1234, dtype=tf.int64))  # int64 constant, as in the original TF1 code
tf.saved_model.save(module, export_path, signatures={ "score": module.score, "metadata": module.metadata })

Tested with a Keras model on TF2.
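
For completeness, here is a minimal sketch of loading the SavedModel back and calling both signatures. Only export_path and the signature names come from the code above; the all-zeros waveform and its length are placeholder values matching the TensorSpec in MyModule:

import tensorflow as tf

loaded = tf.saved_model.load(export_path)

score_fn = loaded.signatures["score"]
metadata_fn = loaded.signatures["metadata"]

# Signature functions accept and return dictionaries of tensors.
scores = score_fn(waveform=tf.zeros((1, 16000, 1), dtype=tf.float32))["scores"]
other_variable = metadata_fn()["other_variable"]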

  • Extremely helpful answer! I scoured the internet for a while before finding it. Also of interest: I tested it out, and you don't even need to wrap it in a tf.Module for a custom Keras Model subclass. You can just add the metadata function as you've written it and everything works, along with tf.keras.models.save/load (a sketch of this variant follows after the comments). Wrote it up here: https://stackoverflow.com/questions/54642590/add-metadata-to-tensorflow-frozen-graph-pb – Carson McNeil Aug 25 '20 at 02:17
  • I am training a text classification model in TF2. I am using Python logic to build up vocab dicts, i.e. item2idx and idx2item, which live in a dataset class object. I tried and failed to export that as the other_variable in place of the simple 1234 that you used. Any idea what might be the issue here? – n0obcoder Oct 14 '21 at 17:39
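
Below is a hedged sketch of the variant Carson McNeil describes: the signature functions live directly on a tf.keras.Model subclass (a Keras model is itself a tf.Module), so no separate wrapper class is needed. The layer, input shape, and export path are made up for illustration; signatures is a regular argument of tf.keras.models.save_model in TF2:

import tensorflow as tf

class ScoringModel(tf.keras.Model):
  def __init__(self, other_variable, **kwargs):
    super().__init__(**kwargs)
    self.dense = tf.keras.layers.Dense(2)
    self._other_variable = other_variable

  def call(self, inputs):
    return self.dense(inputs)

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)])
  def score(self, waveform):
    return { "scores": self(waveform) }

  @tf.function(input_signature=[])
  def metadata(self):
    return { "other_variable": tf.constant(self._other_variable, dtype=tf.int64) }

model = ScoringModel(1234)
model(tf.zeros((1, 3)))  # build the variables once before saving

tf.keras.models.save_model(
    model, "/tmp/scoring_model", save_format="tf",
    signatures={ "score": model.score, "metadata": model.metadata })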