A model is not something that exists in jax a priori: jax is a numerical computing library, not specifically a deep learning library.
The deep learning libraries written on top of jax each have their own definition of a model, since a model is a compound of functions, parameters, and other properties. For example, you can have a flax model, a haiku model, or a trax model (among others).
Each of these libraries defines its own (de)serialisation protocol, and depending on which one you use, both the result and the procedure will differ.
However, most of these model definitions are PyTrees. You can read about PyTrees here, but you can think of a pytree as a nested structure of registered containers (tuples, lists, dicts, or custom registered classes) whose leaves are typically jax arrays.
For example, a tuple of arrays or a dictionary of arrays is a PyTree.
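A small illustration of what "collection of arrays" means in practice: the sketch below builds a nested dict/tuple of arrays and uses `jax.tree_util` to list its leaves and to map a function over them while keeping the structure. The variable names are only for illustration.

```python
import jax
import jax.numpy as jnp

# A dict containing a tuple of arrays and a single array is a pytree.
params = {"layer": (jnp.ones((2, 3)), jnp.zeros((3,))), "scale": jnp.array(2.0)}

# tree_leaves returns the arrays (the leaves); dict keys are visited in sorted order.
leaves = jax.tree_util.tree_leaves(params)
print([leaf.shape for leaf in leaves])  # [(2, 3), (3,), ()]

# tree_map applies a function to every leaf and rebuilds the same structure.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
print(doubled["scale"])
```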
Most frameworks define a model as a PyTree together with some syntactic sugar.
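Being able to serialise a pytree is therefore enough to serialise a model's state; the static parts (such as the forward function) live in your code and are rebuilt when you load the model.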
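To make this concrete without any framework, here is a minimal sketch that serialises an arbitrary pytree using only `jax.tree_util` and numpy's `.npz` format. The helpers `save_pytree` and `load_pytree` are hypothetical names, not part of jax or flax. Only the leaves are written to the file; the structure (`treedef`) must be available again at load time, much like flax needs a template model to deserialise into.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def save_pytree(tree, file):
    """Hypothetical helper: write the leaves of `tree` to an .npz file."""
    # Split the pytree into its array leaves and its structure (the treedef).
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Positional arrays are stored as arr_0, arr_1, ...; the structure is not stored.
    np.savez(file, *[np.asarray(leaf) for leaf in leaves])
    return treedef


def load_pytree(file, treedef):
    """Hypothetical helper: read leaves back and reassemble them with `treedef`."""
    with np.load(file) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(len(data.files))]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Usage: round-trip a small parameter dict through an in-memory file.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
buffer = io.BytesIO()
treedef = save_pytree(params, buffer)
buffer.seek(0)
restored = load_pytree(buffer, treedef)
print(all(jnp.array_equal(a, b) for a, b in zip(jax.tree_util.tree_leaves(params),
                                                jax.tree_util.tree_leaves(restored))))
```

Keeping the treedef out of the file keeps the format trivially portable (a plain `.npz`), at the cost of having to reconstruct the structure in code.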
As you mentioned flax, here is an example of how to (de)serialise a model in flax.
```python
from typing import Any, Callable

import jax.numpy as jnp
from flax import struct
from flax.serialization import to_state_dict, from_state_dict


class Model(struct.PyTreeNode):
    # Pytree field: part of the state that gets serialised.
    params: Any
    # Static field: not a pytree leaf, so it is not serialised.
    forward: Callable = struct.field(pytree_node=False)

    def __call__(self, *args):
        return self.forward(self.params, *args)


params = jnp.ones(())
forward = lambda params, x: params * jnp.mean(x**2)
model = Model(params, forward)

# serialise
serialised_model = to_state_dict(model)
print("Serialised model", serialised_model)
# Outputs:
# Serialised model {'params': Array(1., dtype=float32)}

# deserialise: `model` acts as the template providing structure and static fields
deserialised_model = from_state_dict(model, serialised_model)
print("Deserialised model", deserialised_model)
# Outputs:
# Deserialised model Model(params=Array(1., dtype=float32), forward=<function <lambda> at 0x7efcb4c30d30>)
```
Alternatively, you can use the orbax checkpoint API, as you already mentioned. There is a clear example here.
If you want to load the model into another application (for example, one written in a different language), that application must implement the same (de)serialisation protocol you used in python.