
I need to build an MLP in JAX, but I get slightly different (and, in my opinion, less accurate) results from JAX compared to an MLP created in TensorFlow.

In both cases I created a dataset where y is a linear function of X plus standard Gaussian noise; the dataset is the same in both cases.
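
The data are generated roughly like this (the shapes and coefficients below are placeholders, since I haven't included the real generation code here):

import numpy as np

rng = np.random.default_rng(0)

# Placeholder shapes and coefficients -- only the structure matters:
# y is a linear combination of the columns of X plus standard Gaussian noise.
n_samples, n_features = 1000, 5
X = rng.normal(size=(n_samples, n_features))
w_true = rng.normal(size=(n_features, 1))
y = X @ w_true + rng.standard_normal(size=(n_samples, 1))  # N(0, 1) error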

I initialized the MLP in TensorFlow with the same initialization I used in JAX (to be sure to start with the exact same network).
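
Concretely, the weight copy looks roughly like this; it's only a sketch, assuming `params` is a list of (W, b) pairs from the JAX side and `model` is a `tf.keras.Sequential` of Dense layers with matching shapes:

import numpy as np

# The Keras model must already be built (e.g. by calling model(X) once)
# so that its layers have weight variables to overwrite.
dense_layers = [layer for layer in model.layers if layer.weights]
for layer, (W, b) in zip(dense_layers, params):
    layer.set_weights([np.asarray(W), np.asarray(b)])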

In TensorFlow I fit the network using this:

model.compile(loss=tf.keras.losses.mean_squared_error,
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.00001))
model.fit(X, y, batch_size=X.shape[0], epochs=5000)
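
For completeness, leaving the Keras SGD arguments at their defaults should give plain gradient descent; writing them out explicitly, the compile call above is equivalent to:

# Keras SGD with its defaults spelled out (no momentum, no Nesterov),
# i.e. plain gradient descent. Since batch_size equals X.shape[0],
# the fit above is full-batch SGD.
model.compile(loss=tf.keras.losses.mean_squared_error,
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.00001,
                                                momentum=0.0,
                                                nesterov=False))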

And this is what I get (it seems correct):

[TensorFlow result plot]

Now, in JAX I train the network as follows:

import jax
import jax.numpy as jnp
from jax import grad, jit

# apply_fn(params, x) is the forward pass of the MLP; params are its weights.
loss = lambda params, x, y: jnp.mean((apply_fn(params, x) - y) ** 2)

@jit
def update(params, x, y, learning_rate):
    grad_loss = grad(loss)(params, x, y)
    # SGD update: apply the step to every leaf, i.e. every parameter of the MLP
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grad_loss
    )

learning_rate = 0.00001
num_epochs = 5000
for _ in range(num_epochs):
    params = update(params, X, y, learning_rate)

This is what I get as result:

[JAX result plot]

I notice that if I increase the number of epochs a lot in the JAX implementation it works better (the model predictions get closer and closer to the real values), but how can I get a result from JAX similar to TensorFlow's without increasing the number of epochs?

fabianod
  • It's hard to say for sure what's going on without code to reproduce the issue, but I suspect this comes from JAX doing the optimization in float32 precision, while TF is doing the optimization in float64 precision. See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision. If that's not the issue, then you may have better luck getting a helpful answer if you include a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). – jakevdp Mar 20 '23 at 12:19
  • The problem is that it's kind of hard to provide a minimal reproducible example because I need to give you two examples, one for TF and one for JAX. Anyway, the update procedure I implemented in JAX is correct, right? I mean, TensorFlow doesn't do anything more than that, right? I saw that the SGD class in Keras has many parameters; I left them at their default values, hoping they don't do some kind of extra optimization (which would explain why, with the same number of iterations, TensorFlow gets a better fit). – fabianod Mar 20 '23 at 13:39
  • I honestly don’t know. Did you try 64-bit precision in JAX? – jakevdp Mar 20 '23 at 13:46
  • It took me a while because I'm not using pure JAX, but a library built on top of JAX, so I had to implement the MLP in pure JAX to be able to initialize the weights with float64 vectors. Nothing changed compared to float32, so that's not the problem. – fabianod Mar 20 '23 at 15:33
  • My comment was not about changing the input values, it was about enabling float64 computation globally. See the link I shared in my first comment. If you don't do that, then even float64 inputs will lead to float32 computations. That idea is really a stab in the dark and may or may not work: it's hard to guess much of anything about code that I can't see or run :) – jakevdp Mar 20 '23 at 17:21
  • I saw the link you shared, I enabled float64 computation and to be sure that all is working fine I implemented the MLP in pure Jax. Nothing changed. – fabianod Mar 22 '23 at 12:20
  • Ok, well best of luck then. If you’re able to share complete code, I think you’d be likely to get better answers. – jakevdp Mar 22 '23 at 12:50
  • If you want to look at the code you can find it in this [repository](https://github.com/FabianoVeglianti/temporaryCodeRepository) – fabianod Mar 22 '23 at 14:40
  • Before running any SGD step, your TF loss function returns `3512` and your JAX loss function returns `3599`. This suggests that the two models are not equivalent. I would start by looking at how the models are defined and figuring out how they're different. – jakevdp Mar 22 '23 at 21:45
  • Good point, thank you a lot! I will investigate in that direction! – fabianod Mar 23 '23 at 22:21
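
A minimal sketch of the checks suggested in the comments: enabling 64-bit computation globally in JAX and comparing the initial losses of the two models before any SGD step. It uses the `loss`, `params`, `model`, `X` and `y` defined above:

import jax
jax.config.update("jax_enable_x64", True)  # must be set before any other JAX computation

import tensorflow as tf

# If the two networks really share the same weights, the losses before any
# training step should agree (above they were 3599 for JAX vs 3512 for TF).
initial_loss_jax = loss(params, X, y)
initial_loss_tf = tf.reduce_mean(tf.keras.losses.mean_squared_error(y, model(X)))
print(float(initial_loss_jax), float(initial_loss_tf))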

0 Answers