This works perfectly:
def f_jax(x):
return jnp.sin(jnp.cos(x))
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(batch, _)"])
f_tf = tf.function(f_tf, autograph=False)
f_tf = f_tf.get_concrete_function(
tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
)
f_layer = hub.KerasLayer(f_tf)
x = tf.keras.layers.Input(shape=(2,), dtype=tf.float32)
y = f_layer(x)
model = tf.keras.Model(inputs=[x], outputs=[y])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
intepreter = tf.lite.Interpreter(model_content=tflite_model)
intepreter.allocate_tensors()
input_details = intepreter.get_input_details()
output_details = intepreter.get_output_details()
intepreter.set_tensor(input_details[0]["index"], np.array([[1.0, 0.0]], dtype=np.float32))
intepreter.invoke()
intepreter.get_tensor(output_details[0]["index"])
When I add a second parameter to f_jax
, the call to z = f_layer(x, y)
fails:
def f_jax(x, y):
return jnp.sin(jnp.cos(x + y))
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(batch, _)", "(batch, _)"])
f_tf = tf.function(f_tf, autograph=False)
f_tf = f_tf.get_concrete_function(
tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
tf.TensorSpec(shape=(None, 2), dtype=tf.float32),
)
f_layer = hub.KerasLayer(f_tf)
x = tf.keras.layers.Input(shape=(2,), dtype=tf.float32)
y = tf.keras.layers.Input(shape=(2,), dtype=tf.float32)
try:
z = f_layer([x, y])
except Exception as e:
print(e)
model = tf.keras.Model(inputs=[x, y], outputs=[z])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
intepreter = tf.lite.Interpreter(model_content=tflite_model)
intepreter.allocate_tensors()
input_details = intepreter.get_input_details()
output_details = intepreter.get_output_details()
intepreter.set_tensor(input_details[0]["index"], np.array([[1.0, 0.0]], dtype=np.float32))
intepreter.set_tensor(input_details[1]["index"], np.array([[0.0, 1.0]], dtype=np.float32))
intepreter.invoke()
intepreter.get_tensor(output_details[0]["index"])
The exception:
Exception encountered when calling layer "keras_layer_100" (type KerasLayer).
in user code:
File "/home/myuser/.local/lib/python3.10/site-packages/tensorflow_hub/keras_layer.py",
line 234, in call * result = f()
TypeError: converted_fun_tf(arg1, arg2) missing required arguments: arg2.
Call arguments received by layer "keras_layer_100" (type KerasLayer): • inputs=['tf.Tensor(shape=(None, 2), dtype=float32)', 'tf.Tensor(shape=(None, 2), dtype=float32)'] • training=None