I need to build an MLP in JAX, but I get slightly different (and, in my opinion, less accurate) results from JAX than from an equivalent MLP built in TensorFlow.
In both cases I use the same dataset, where y is a linear function of X plus standard Gaussian noise.
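To make the data setup concrete, here is a minimal sketch of such a dataset. The sample count, slope, intercept, input range, and seed are hypothetical placeholders, since the actual values are not given here. Both X and y are kept as column vectors of shape (n, 1).

```python
import numpy as np

# Hypothetical sizes and coefficients; the real values are not stated in this post.
N_SAMPLES = 200
TRUE_SLOPE = 3.0
TRUE_INTERCEPT = 1.0


def make_dataset(n_samples=N_SAMPLES, seed=0):
    """Return X and y, both of shape (n_samples, 1).

    y = TRUE_SLOPE * X + TRUE_INTERCEPT + standard Gaussian noise.
    """
    rng = np.random.default_rng(seed)
    X = rng.uniform(-5.0, 5.0, size=(n_samples, 1)).astype(np.float32)
    noise = rng.standard_normal(size=(n_samples, 1)).astype(np.float32)
    y = TRUE_SLOPE * X + TRUE_INTERCEPT + noise
    return X, y


# The same arrays would be passed to both the TensorFlow and the JAX model.
X, y = make_dataset()
```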
I initialized the TensorFlow MLP with the same weights I used in JAX, to be sure that both start from exactly the same network.
In TensorFlow I fit the network like this:
model.compile(loss=tf.keras.losses.mean_squared_error,
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.00001))
# batch_size equals the number of samples, so each epoch is one full-batch gradient step
model.fit(X, y, batch_size=X.shape[0], epochs=5000)
And this is what I get (it seems correct):
Now, in JAX I train the network as follows:
import jax
import jax.numpy as jnp
from jax import grad, jit

# apply_fn(params, x) is the MLP forward pass; params is its pytree of weights
loss = lambda params, x, y: jnp.mean((apply_fn(params, x) - y) ** 2)

@jit
def update(params, x, y, learning_rate):
    grad_loss = grad(loss)(params, x, y)
    # SGD update, applied to every leaf (i.e. every parameter) of the MLP
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grad_loss
    )

learning_rate = 0.00001
num_epochs = 5000
# One full-batch gradient step per epoch
for _ in range(num_epochs):
    params = update(params, X, y, learning_rate)
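For anyone who wants to reproduce this, the loop above can be made self-contained roughly as follows. This is only a sketch: `init_params`, the one-hidden-layer `apply_fn`, the toy data, the step count, and the larger learning rate are hypothetical stand-ins, not my actual network or settings. The shape assertion is there because an `(N, 1)` prediction minus an `(N,)` target would broadcast to `(N, N)` without raising an error, which is worth ruling out when comparing against Keras.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in=1, n_hidden=8, n_out=1):
    """Hypothetical initializer: small random weights, zero biases."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": 0.1 * jax.random.normal(k1, (n_in, n_hidden)),
        "b1": jnp.zeros((n_hidden,)),
        "W2": 0.1 * jax.random.normal(k2, (n_hidden, n_out)),
        "b2": jnp.zeros((n_out,)),
    }


def apply_fn(params, x):
    """Hypothetical one-hidden-layer MLP; returns predictions of shape (N, 1)."""
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


def loss(params, x, y):
    pred = apply_fn(params, x)
    # Shapes are static, so this check also runs when the function is traced by jit.
    assert pred.shape == y.shape, f"shape mismatch: {pred.shape} vs {y.shape}"
    return jnp.mean((pred - y) ** 2)


@jax.jit
def update(params, x, y, learning_rate):
    grads = jax.grad(loss)(params, x, y)
    # Plain gradient-descent step on every leaf of the parameter pytree.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)


# Toy data (hypothetical): y = 2 * X + 0.5 + standard Gaussian noise, both (100, 1).
k_x, k_noise, k_init = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.uniform(k_x, (100, 1), minval=-1.0, maxval=1.0)
y = 2.0 * X + 0.5 + jax.random.normal(k_noise, (100, 1))

params = init_params(k_init)
initial_loss = loss(params, X, y)
for _ in range(200):  # hypothetical step count and learning rate for the toy problem
    params = update(params, X, y, 0.1)
final_loss = loss(params, X, y)
```

With matching shapes the loss decreases over the 200 steps; passing `y.ravel()` instead of `y` trips the assertion rather than silently broadcasting.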
This is the result I get:
I notice that if I greatly increase the number of epochs in the JAX implementation, it works better (the model predictions get closer and closer to the true values). How can I get JAX to match the TensorFlow result without increasing the number of epochs?