I have a pre-trained JAX model for MAXIM (image enhancement). To reduce its runtime for production use, I need to quantize the weights. Since there is no direct conversion from JAX to ONNX, I have two options:

  1. JAX -> Tensorflow -> ONNX (Help Thread)
  2. JAX -> TFLite

Going with the second option, there is the function tf.lite.TFLiteConverter.experimental_from_jax.

Looking at this official example, the code block

serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)

seems to use the model's params and a function predict, both of which in that example are defined during model building and training:

predict:

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

params:

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
params = get_params(opt_state)

My question: how can I obtain the equivalent params and predict for my pre-trained model, so that I can replicate this example with my own model?

Deshwal

1 Answer

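I got an answer on the official repo. The idea is to wrap model.apply in a predict function, convert it to a TensorFlow function with jax2tf, and pass that to the TFLite converter. Here is the code: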

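import flax
import tensorflow as tf
from jax.experimental import jax2tf

# `model` (the Flax module) and `params` (its pre-trained weights)
# are assumed to be loaded already.


def predict(input_img):
  '''Run the JAX model on a batch of input images.'''
  return model.apply({'params': flax.core.freeze(params)}, input_img)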


tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, 704, 1024, 3], dtype=tf.float32, name='input_image')
    ],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_float_model = converter.convert()

with open('float_model.tflite', 'wb') as f:
  f.write(tflite_float_model)

# Convert again with the default optimizations, which quantize the weights.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()

with open('./quantized.tflite', 'wb') as f:
  f.write(tflite_quantized_model)

You can now load and run either model using tf.lite.Interpreter.
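As a rough sketch of that last step, something like the following should work for a model with a single input. The helper name run_tflite is made up for illustration, and it assumes the first output tensor is the one you want; MAXIM may expose several outputs, so check get_output_details() for your converted file.

```python
import numpy as np
import tensorflow as tf


def run_tflite(model_path, image):
    """Run one inference on a .tflite file and return its first output tensor.

    `run_tflite` is a hypothetical helper, not part of TensorFlow.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Assumption: the model has exactly one input; we read only the first output.
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]

    # Cast the input to the dtype the model expects (float32 for the models above).
    interpreter.set_tensor(input_detail['index'],
                           image.astype(input_detail['dtype']))
    interpreter.invoke()
    return interpreter.get_tensor(output_detail['index'])
```

For the models saved above, the call would look like run_tflite('float_model.tflite', img), where img is a NumPy array of shape (1, 704, 1024, 3) to match the input signature used during conversion.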

Deshwal