I am using Flax to create a network:

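import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  # A dummy input with batch size 1 is used to initialize the parameters.
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)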

The input dimensions used for initialization are [1, 28, 28, 1]. In my custom training loop, I need to pass in inputs with different batch sizes, such as [5, 28, 28, 1]. How can I implement this in Flax? In JAX you can use vmap, but I'm not sure whether that applies here.

MichaelMMeskhi

0 Answers