
In Python I've trained a TensorFlow LinearClassifier and saved it like:

model = tf.contrib.learn.LinearClassifier(feature_columns=columns)
model.fit(input_fn=train_input_fn, steps=100)
model.export_savedmodel(export_dir, parsing_serving_input_fn)

By using the TensorFlow Java API I am able to load this model in Java using:

SavedModelBundle model = SavedModelBundle.load(export_dir, "serve");

It seems I should be able to run the graph using something like

model.session().runner().feed(???, ???).fetch(???, ???).run()

but what variable names/data should I feed to/fetch from the graph to provide it features and to fetch the probabilities of the classes? The Java documentation is lacking this information as far as I can see.

1 Answer

The names of the nodes to feed depend on what parsing_serving_input_fn does: they should be the names of the Tensor objects returned by parsing_serving_input_fn. The names of the nodes to fetch depend on what you're predicting (the arguments you would pass to model.predict() if using your model from Python).

That said, the TensorFlow saved model format does include the "signature" of the model (i.e., the names of all Tensors that can be fed or fetched) as metadata that can provide hints.

From Python you can load the saved model and list out its signature using something like:

with tf.Session() as sess:
  md = tf.saved_model.loader.load(sess, ['serve'], export_dir)
  sig = md.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  print(sig)

Which will print something like:

inputs {
  key: "inputs"
  value {
    name: "input_example_tensor:0"
    dtype: DT_STRING
    tensor_shape {
      dim {
        size: -1
      }
    }
  }
}
outputs {
  key: "scores"
  value {
    name: "linear/binary_logistic_head/predictions/probabilities:0"
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: -1
      }
      dim {
        size: 2
      }
    }
  }
}
method_name: "tensorflow/serving/classify"

This suggests that what you want to do in Java is:

Tensor t = createInputTensor(); // hypothetical helper returning the Tensor to be fed
Tensor probabilities = model.session().runner()
    .feed("input_example_tensor", t)
    .fetch("linear/binary_logistic_head/predictions/probabilities").run().get(0);
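According to the printed signature, the fetched probabilities tensor has shape [-1, 2], which means one row per fed example and one column per class. Below is a minimal sketch that picks the most probable class for each row. It assumes the tensor has already been copied into a Java array; in the TensorFlow Java 1.x API, `Tensor.copyTo` fills a `float[batchSize][2]` array that you allocate yourself. `BinaryPredictions` and `argmax` are hypothetical names and are not part of TensorFlow.

```java
/**
 * Hypothetical helper (not part of the TensorFlow Java API): converts a
 * [batch, 2] array of class probabilities into predicted class indices.
 */
final class BinaryPredictions {
    private BinaryPredictions() {}

    /** Returns, for each row, the column index with the highest probability. */
    static int[] argmax(float[][] probabilities) {
        int[] classes = new int[probabilities.length];
        for (int row = 0; row < probabilities.length; row++) {
            float[] p = probabilities[row];
            int best = 0;
            // Scan the remaining columns and keep the index of the largest value.
            for (int c = 1; c < p.length; c++) {
                if (p[c] > p[best]) {
                    best = c;
                }
            }
            classes[row] = best;
        }
        return classes;
    }
}
```

As an example, let `result` be the fetched tensor and `batchSize` the number of examples you fed. Then `BinaryPredictions.argmax(result.copyTo(new float[batchSize][2]))` gives one class index per example.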

You can also extract this information purely within Java if your program includes the generated Java code for TensorFlow protocol buffers (packaged in the org.tensorflow:proto artifact) using something like this:

// Same as tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
// in Python. Perhaps this should be an exported constant in TensorFlow's Java API.
final String DEFAULT_SERVING_SIGNATURE_DEF_KEY = "serving_default"; 

final SignatureDef sig =
      MetaGraphDef.parseFrom(model.metaGraphDef())
          .getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);

You will have to add:

import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
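Tensor names in the signature include an output-index suffix, as in input_example_tensor:0. The Java call shown earlier, by contrast, feeds and fetches bare operation names. In TensorFlow Java 1.x, `Session.Runner` also provides `feed(String, int, Tensor)` and `fetch(String, int)` overloads, which take the operation name and the output index as separate arguments. The sketch below splits a name read from the SignatureDef into those two parts. It uses a hypothetical `TensorName` helper that is not part of TensorFlow. The input string could be, for example, the value of `getName()` on an entry of `sig.getInputsMap()`, assuming the standard protobuf-java accessors.

```java
/**
 * Hypothetical helper (not a TensorFlow class): splits a SignatureDef tensor
 * name of the form "operation_name:output_index" into its two parts.
 */
final class TensorName {
    final String operation;
    final int outputIndex;

    private TensorName(String operation, int outputIndex) {
        this.operation = operation;
        this.outputIndex = outputIndex;
    }

    /** Parses "op_name:index"; a name without a ':' suffix is treated as output 0. */
    static TensorName parse(String name) {
        int colon = name.lastIndexOf(':');
        if (colon < 0) {
            return new TensorName(name, 0);
        }
        String op = name.substring(0, colon);
        int index = Integer.parseInt(name.substring(colon + 1));
        return new TensorName(op, index);
    }
}
```

After parsing, `TensorName in = TensorName.parse(name);` followed by `runner.feed(in.operation, in.outputIndex, t)` feeds exactly the tensor that the signature names.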

Since the Java API and the SavedModel format are fairly new, there is much room for improvement in the documentation.

Hope that helps.

ash
  • Thanks for the answer! This looks promising. However, what do I have to provide for the input_example_tensor? Consider, for example, the [TensorFlow Iris classification tutorial](https://www.tensorflow.org/get_started/tflearn): exporting that model results in the same signature as you provide (inputs, dtype: DT_STRING), but I need to somehow feed this model 4 numbers. – Jan Kuipers May 10 '17 at 10:44
  • As I understand it now, the model wants a serialized Example protocol buffer, but at this moment (1) the protocol buffer isn't available in Java and (2) creating Tensors with DataType String (which is needed for serialized Examples) is not supported yet. :( – Jan Kuipers May 11 '17 at 14:16
  • FYI: Protocol buffers are available in Java in the [org.tensorflow:proto](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/proto) Maven artifact ([javadoc](http://javadoc.io/doc/org.tensorflow/proto/)). DataType.STRING tensors are supported for scalars (i.e., a single string), but not yet for multi-dimensional arrays (https://github.com/tensorflow/tensorflow/issues/8531). Hope that helps. – ash May 11 '17 at 17:21
  • Thanks again for the feedback. It's good to know that the protos are available in Java too. Regarding Tensors with Strings: I need to feed a vector of Strings to input_example_tensor, right? So a String Scalar isn't helpful at the moment. Or can I work around this? – Jan Kuipers May 12 '17 at 07:45
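What goes into input_example_tensor? Each element is a serialized tf.Example protocol buffer. If the org.tensorflow:proto artifact is on the classpath, you would normally build it with the generated org.tensorflow.example classes. The sketch below takes a different route: it writes the protobuf wire format by hand, so it has no dependencies. The field numbers follow TensorFlow's example.proto and feature.proto. The sketch assumes a single real-valued feature column that holds all four Iris measurements. The feature name "x" in the usage note is hypothetical and must match the name of the feature column the model was exported with. `ExampleEncoder` is a hypothetical helper, not a TensorFlow API.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

/**
 * Hypothetical helper: hand-encodes a tf.Example with one FloatList feature
 * using the protobuf wire format, instead of the generated proto classes.
 */
final class ExampleEncoder {
    private ExampleEncoder() {}

    /** Returns the bytes of Example{features{feature{key: name, value{float_list{value: values}}}}}. */
    static byte[] floatExample(String featureName, float[] values) {
        // FloatList.value (field 1) is a packed repeated float: 4 little-endian bytes each.
        ByteBuffer packed = ByteBuffer.allocate(4 * values.length).order(ByteOrder.LITTLE_ENDIAN);
        for (float v : values) {
            packed.putFloat(v);
        }
        byte[] floatList = lengthDelimited(1, packed.array());
        // Feature: float_list is field 2 of the "kind" oneof.
        byte[] feature = lengthDelimited(2, floatList);
        // Map entry of Features.feature: key is field 1, value is field 2.
        byte[] entry = concat(
                lengthDelimited(1, featureName.getBytes(StandardCharsets.UTF_8)),
                lengthDelimited(2, feature));
        // Features.feature is field 1; Example.features is field 1.
        byte[] features = lengthDelimited(1, entry);
        return lengthDelimited(1, features);
    }

    /** Writes tag (wire type 2, length-delimited), then payload length, then payload. */
    private static byte[] lengthDelimited(int fieldNumber, byte[] payload) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        writeVarint(out, (fieldNumber << 3) | 2);
        writeVarint(out, payload.length);
        out.write(payload, 0, payload.length);
        return out.toByteArray();
    }

    /** Base-128 varint: 7 bits per byte, high bit set on all but the last byte. */
    private static void writeVarint(ByteArrayOutputStream out, int value) {
        while ((value & ~0x7F) != 0) {
            out.write((value & 0x7F) | 0x80);
            value >>>= 7;
        }
        out.write(value);
    }

    private static byte[] concat(byte[] a, byte[] b) {
        byte[] result = new byte[a.length + b.length];
        System.arraycopy(a, 0, result, 0, a.length);
        System.arraycopy(b, 0, result, a.length, b.length);
        return result;
    }
}
```

For instance, `ExampleEncoder.floatExample("x", new float[] {5.1f, 3.5f, 1.4f, 0.2f})` returns the bytes for one example. How to wrap such bytes in a DT_STRING tensor of shape [-1] depends on the version of the Java API; see the issue linked in the comments.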