I am using Flax to create a network:
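    import jax.numpy as jnp
    import optax
    from flax.training import train_state

    # CNN is assumed to be a flax.linen Module defined elsewhere.
    def create_train_state(rng, learning_rate, momentum):
        """Creates initial `TrainState`."""
        cnn = CNN()
        # The dummy input is only used to infer parameter shapes.
        params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
        tx = optax.sgd(learning_rate, momentum)
        return train_state.TrainState.create(
            apply_fn=cnn.apply, params=params, tx=tx)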
The input dimensions used for initialization are [1, 28, 28, 1], but in my custom training loop I need to pass in inputs with different batch sizes, such as [5, 28, 28, 1]. How can I implement this in Flax? In JAX you can use vmap, but I am not sure how that applies here.