I have a pre-trained JAX model for MAXIM: Image Enhancement. To reduce runtime and use it in production, I'll have to quantize the weights. Since there is no direct conversion from JAX to ONNX, I have two options:
- JAX -> TensorFlow -> ONNX (Help Thread); sketched below
- JAX -> TFLite
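
For reference, my understanding is that the first route would go through `jax2tf` and then `tf2onnx`, roughly like this (an untested sketch; `serving_func` is the same partial function built in the TFLite example further below):

```python
import tensorflow as tf
from jax.experimental import jax2tf

# serving_func = functools.partial(predict, params), as in the TFLite example below.
tf_predict = tf.function(
    jax2tf.convert(serving_func),
    input_signature=[tf.TensorSpec([1, 28, 28], tf.float32)],
    autograph=False,
)

module = tf.Module()
module.serve = tf_predict
tf.saved_model.save(module, 'saved_model_dir')

# Then convert the SavedModel with the tf2onnx CLI:
#   python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx
```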
Going with the second option, there is the function `tf.lite.TFLiteConverter.experimental_from_jax`. Looking at this official example, there is this code block:
```python
import functools

import jax.numpy as jnp
import tensorflow as tf

# Wrap the model's apply function with its trained parameters.
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))  # dummy input that fixes the input shape

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
    f.write(tflite_model)
```
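
Since my end goal is quantizing the weights, I assume the standard post-training quantization flag applies to this converter as well (based on the general TFLite docs; I have not verified it on the JAX path):

```python
# Post-training dynamic-range quantization of the weights.
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
```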
The example seems to be using the `params` from the model and a `predict` function, which are defined during model building and training itself. `predict` comes from:
```python
from jax.example_libraries import stax

# stax.serial returns an (init_fun, apply_fun) pair; apply_fun is `predict`.
init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)
```
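
As I understand it, `init_random_params` creates the initial parameters and `predict(params, inputs)` is the forward pass, e.g.:

```python
from jax import random

rng = random.PRNGKey(0)
# init_fun(rng, input_shape) -> (output_shape, initial_params)
_, init_params = init_random_params(rng, (-1, 28, 28))
```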
The `params`, in turn, come from the optimizer:

```python
from jax.example_libraries import optimizers

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
params = get_params(opt_state)
```
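
where `opt_state` is what the training loop carries around, starting from the initial parameters (my paraphrase of the example's training loop):

```python
# The optimizer state wraps the parameters during training.
opt_state = opt_init(init_params)
# ... each step: opt_state = opt_update(step, grads, opt_state) ...
params = get_params(opt_state)  # extract the trained weights afterwards
```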
My question is: how can I get these two required objects, `params` and `predict`, for my pre-trained model, so that I can replicate this example for my own model?
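
My current guess, assuming MAXIM is a Flax model, is that `predict` corresponds to the module's `apply` method and `params` to the restored checkpoint, along these lines (hypothetical names; `model` and `ckpt_params` stand for my built MAXIM module and its restored weights):

```python
import functools

import jax.numpy as jnp
import tensorflow as tf

# Hypothetical: `model` is the MAXIM Flax module, `ckpt_params` the restored weights.
serving_func = functools.partial(model.apply, {'params': ckpt_params})
x_input = jnp.zeros((1, 256, 256, 3))  # placeholder input shape for an image model

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
```

Is that the right way to obtain `predict` and `params` here?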