
I am trying to create a physics-informed neural network (PINN) in JAX. I want to differentiate the defined model (a neural network) with respect to its input x. If I call jax.grad(params), I get an error.
If I call jax.grad(model) instead, I don't get an error, but I don't know whether this actually differentiates the neural network's output with respect to x.

import jax
import jax.numpy as jnp
import flax.linen as fnn
import optax
from flax.training.train_state import TrainState


class MLP(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Dense(128)(x)
        x = fnn.relu(x)
        x = fnn.Dense(256)(x)
        x = fnn.relu(x)
        x = fnn.Dense(10)(x)
        return x


model = MLP()
# Initialize parameters with a dummy input of shape (1,).
params = model.init(jax.random.PRNGKey(0), jnp.ones([1]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
joel

1 Answer


You can differentiate a model in JAX by (1) defining a function that you want to differentiate, (2) transforming it with jax.grad, jax.jacrev, jax.jacfwd, etc. as appropriate for your application, and (3) passing data to the transformed function.

It's not entirely clear from your question what operation you're hoping to differentiate, but here is an example of computing a forward-mode jacobian of the training state creation with respect to the params:

def f(params):
  return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Forward-mode Jacobian of the resulting TrainState with respect to params.
result = jax.jacfwd(f)(params)
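
If instead you want to differentiate the network output with respect to its input x, as is typical for a PINN, a minimal sketch (assuming the MLP and params defined in the question) could look like this:

def u(x):
  # Apply the network to a single input point.
  return model.apply({'params': params}, x)

x = jnp.ones([1])

# Jacobian of the 10-dimensional output with respect to the 1-dimensional input.
du_dx = jax.jacfwd(u)(x)

Because the final Dense(10) layer returns a vector, jax.jacfwd or jax.jacrev are the appropriate transforms here; jax.grad only applies once the function is reduced to a scalar output.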

If that doesn't help, I'd suggest editing your question to make clear what operation you're interested in differentiating.

jakevdp