
I am learning JAX from a PyTorch background.

I am used to saving serialized PyTorch models as .pt files then deploying them into another application for evaluation.

What is the standard way of doing this with JAX?

I looked at the Flax guide here: https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html

And it seems like there are so many options, with Flax, Orbax, etc. Is there a standard and proper way of simply saving a model, then loading it in another application for evaluation purposes?


A model is not something that exists in JAX a priori. JAX is a general numerical computing library, not a deep learning library, so it has no built-in notion of a model.

The deep learning libraries written on top of JAX each have their own definition of a model, since a model is a compound of functions, parameters, and other properties. For example, you can have a Flax model, a Haiku model, or a Trax model (among others). Each of these libraries defines its own (de)serialisation protocol, and depending on which one you use, the procedure and the result will of course differ.

However, most of these model definitions are PyTrees. You can read about PyTrees in the JAX documentation, but you can think of a pytree as nothing more than a (registered) collection of JAX arrays. For example, a tuple of arrays or a dictionary of arrays is a PyTree. Most frameworks define a model as a PyTree plus some syntactic sugar, so being able to serialise a pytree is enough to be able to serialise a model.
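To make this concrete, here is a minimal sketch (the dictionary and its keys are arbitrary choices for illustration) showing that a plain dict of arrays is already a pytree, and that `jax.tree_util` can take it apart into leaves plus a structure description and put it back together — which is exactly what a serialisation layer needs:

```python
import jax
import jax.numpy as jnp

# A plain dictionary of arrays is already a PyTree.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# tree_flatten splits any pytree into a flat list of array leaves
# plus a treedef describing the container structure.
leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))  # 2 leaves: the two arrays
print(treedef)

# The treedef alone can rebuild the original structure from the leaves,
# so persisting "leaves + structure" is enough to persist the pytree.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(rebuilt["w"].shape)  # (2, 3)
```

This leaves/treedef split is the mechanism the framework-level (de)serialisation protocols build on.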

Since you mentioned Flax, here is an example of how to (de)serialise a model in Flax.

from typing import Any, Callable
import jax.numpy as jnp
from flax import struct
from flax.serialization import to_state_dict, from_state_dict


class Model(struct.PyTreeNode):
  params: Any
  forward: Callable = struct.field(pytree_node=False)

  def __call__(self, *args):
    return self.forward(*args)


params = jnp.ones(())
forward = lambda params, x: params * jnp.mean(x**2)
model = Model(params, forward)

# serialise
serialised_model = to_state_dict(model)
print("Serialised model", serialised_model)
# Outputs:
# Serialised model {'params': Array(1., dtype=float32)}

# deserialise
deserialised_model = from_state_dict(model, serialised_model)
print("Deserialised model", deserialised_model)
# Outputs:
# Deserialised model Model(params=Array(1., dtype=float32), forward=<function <lambda> at 0x7efcb4c30d30>)

Alternatively, you can use the Orbax checkpoint API, as you already mentioned; the Flax checkpointing guide you linked walks through a clear example.

If you want to load the model into another application (for example, one written in a different language), that application must implement the (de)serialisation protocol you used on the Python side.