I am trying to create a physics-informed neural network (PINN) in JAX. I want to differentiate the defined model (the neural network) with respect to its input (x). If I set model to jax.grad(params), I get an error. If I set model to jax.grad(model), I don't get an error, but I don't know whether this actually differentiates the network with respect to x.
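import jax
import jax.numpy as jnp
import flax.linen as fnn
import optax
from flax.training.train_state import TrainState

class MLP(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Dense(128)(x)
        x = fnn.relu(x)
        x = fnn.Dense(256)(x)
        x = fnn.relu(x)
        x = fnn.Dense(10)(x)
        return x

# Initialize the parameters with a dummy input of shape (1,)
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)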