0

This works perfectly:

def f_jax(x):
    """Return sin(cos(x)), computed element-wise with jax.numpy."""
    cos_x = jnp.cos(x)
    return jnp.sin(cos_x)


# Single-input pipeline: JAX function -> TF concrete function -> Keras model
# -> TFLite model -> TFLite interpreter run.

# Convert f_jax with jax2tf, passing one polymorphic shape spec for its one argument.
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(batch, _)"])
# Wrap the converted callable in a tf.function with autograph turned off.
f_tf = tf.function(f_tf, autograph=False)
# Specialize it to a float32 input whose shape is (None, 2).
f_tf = f_tf.get_concrete_function(
    tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
)
# Expose the concrete function as a Keras layer and apply it to a symbolic input.
f_layer = hub.KerasLayer(f_tf)
x = tf.keras.layers.Input(shape=(2,), dtype=tf.float32)
y = f_layer(x)

# Build a Keras model around the layer and convert it to a TFLite model.
model = tf.keras.Model(inputs=[x], outputs=[y])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Load the converted model into a TFLite interpreter.
# NOTE(review): `intepreter` is a misspelling of `interpreter`; the name is kept as written.
intepreter = tf.lite.Interpreter(model_content=tflite_model)
intepreter.allocate_tensors()
input_details = intepreter.get_input_details()
output_details = intepreter.get_output_details()
# Feed a single float32 row to the first input, run the model, and read the first output.
intepreter.set_tensor(input_details[0]["index"], np.array([[1.0, 0.0]], dtype=np.float32))
intepreter.invoke()
# The fetched tensor is not assigned to a name here.
intepreter.get_tensor(output_details[0]["index"])

When I add a second parameter to f_jax, the call z = f_layer([x, y]) fails:

def f_jax(x, y):
    """Return sin(cos(x + y)), computed element-wise with jax.numpy."""
    total = x + y
    cos_total = jnp.cos(total)
    return jnp.sin(cos_total)


# Two-input variant of the pipeline: JAX function -> TF concrete function
# -> Keras model -> TFLite model -> TFLite interpreter run.

# Convert the two-argument f_jax with jax2tf, passing one polymorphic shape spec per argument.
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(batch, _)", "(batch, _)"])
# Wrap the converted callable in a tf.function with autograph turned off.
f_tf = tf.function(f_tf, autograph=False)
# Specialize it to two separate float32 inputs, each with shape (None, 2).
f_tf = f_tf.get_concrete_function(
    tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
    tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
)
# Expose the concrete function as a Keras layer and create two symbolic inputs.
f_layer = hub.KerasLayer(f_tf)
x = tf.keras.layers.Input(shape=(2,), dtype=tf.float32)
y = tf.keras.layers.Input(shape=(2,), dtype=tf.float32)
# Both inputs are handed to the layer as ONE list argument.
# NOTE(review): if this call raises, the exception is only printed and `z` is
# never bound, so the tf.keras.Model(...) line below would then fail with a
# NameError on `z`.
try:
    z = f_layer([x, y])
except Exception as e:
    print(e)

# Build a Keras model around the layer and convert it to a TFLite model.
model = tf.keras.Model(inputs=[x, y], outputs=[z])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Load the converted model into a TFLite interpreter.
# NOTE(review): `intepreter` is a misspelling of `interpreter`; the name is kept as written.
intepreter = tf.lite.Interpreter(model_content=tflite_model)
intepreter.allocate_tensors()
input_details = intepreter.get_input_details()
output_details = intepreter.get_output_details()
# Feed one float32 row to each of the two inputs, run the model, and read the first output.
intepreter.set_tensor(input_details[0]["index"], np.array([[1.0, 0.0]], dtype=np.float32))
intepreter.set_tensor(input_details[1]["index"], np.array([[0.0, 1.0]], dtype=np.float32))
intepreter.invoke()
# The fetched tensor is not assigned to a name here.
intepreter.get_tensor(output_details[0]["index"])

The exception:

Exception encountered when calling layer "keras_layer_100" (type KerasLayer).

in user code:

File "/home/myuser/.local/lib/python3.10/site-packages/tensorflow_hub/keras_layer.py", line 234, in call  *
    result = f()

TypeError: converted_fun_tf(arg1, arg2) missing required arguments: arg2.

Call arguments received by layer "keras_layer_100" (type KerasLayer):
  • inputs=['tf.Tensor(shape=(None, 2), dtype=float32)', 'tf.Tensor(shape=(None, 2), dtype=float32)']
  • training=None

galah92
  • 3,621
  • 2
  • 29
  • 55

0 Answers0