I'm trying to implement and train a neural network using the JAX library and its little neural network submodule, "Stax". Since this library doesn't come with an implementation of binary cross entropy, I wrote my own:
def binary_cross_entropy(y_hat, y):
    bce = y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat)
    return jnp.mean(-bce)
I implemented a simple neural network and trained it on MNIST, and started to get suspicious of some of the results I was getting. So I implemented the same setup in Keras, and I immediately got wildly different results! The same model, trained in the same way on the same data, was getting 90% training accuracy in Keras instead of around 50% in JAX. Eventually I tracked down part of the issue to my naive implementation of cross-entropy, which is supposedly numerically unstable. Following this post and this code I found, I wrote the following new version:
def binary_cross_entropy_stable(y_hat, y):
    y_hat = jnp.clip(y_hat, 0.000001, 0.9999999)
    logits = jnp.log(y_hat/(1 - y_hat))
    max_logit = jnp.clip(logits, 0, None)
    bces = logits - logits * y + max_logit + jnp.log(jnp.exp(-max_logit) + jnp.exp(-logits - max_logit))
    return jnp.mean(bces)
This works a little better. Now my JAX implementation gets up to 80% train accuracy, but that's still a lot less than the 90% Keras gets. What I want to know is: what is going on? Why are my two implementations not behaving the same way?
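For reference, the two loss functions above can be compared in isolation. The sketch below re-creates both with NumPy in place of jax.numpy (an assumption made only so that it runs without JAX; the names bce_naive and bce_clipped and all sample inputs are made up for illustration). It checks that the two agree on moderate predictions, and shows how they differ once a float32 sigmoid output rounds to exactly 1.

```python
import numpy as np

def bce_naive(y_hat, y):
    # Same formula as binary_cross_entropy; NumPy warnings from log(0) are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        bce = y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat)
    return np.mean(-bce)

def bce_clipped(y_hat, y):
    # Same steps as binary_cross_entropy_stable, with np instead of jnp.
    y_hat = np.clip(y_hat, 0.000001, 0.9999999)
    logits = np.log(y_hat / (1 - y_hat))
    max_logit = np.clip(logits, 0, None)
    bces = logits - logits * y + max_logit + np.log(
        np.exp(-max_logit) + np.exp(-logits - max_logit)
    )
    return np.mean(bces)

# Case 1: made-up predictions well inside (0, 1); both versions should agree.
y_moderate = np.array([0.0, 1.0, 0.0, 1.0], dtype=np.float32)
p_moderate = np.array([0.1, 0.4, 0.6, 0.9], dtype=np.float32)
naive_moderate = bce_naive(p_moderate, y_moderate)
clipped_moderate = bce_clipped(p_moderate, y_moderate)
print("moderate :", naive_moderate, clipped_moderate)

# Case 2: a large pre-activation makes the float32 sigmoid round to exactly 1.0.
pre_activations = np.array([30.0, 30.0], dtype=np.float32)
p_saturated = 1.0 / (1.0 + np.exp(-pre_activations))
y_saturated = np.array([1.0, 0.0], dtype=np.float32)
naive_saturated = bce_naive(p_saturated, y_saturated)      # nan, from 0 * log(0)
clipped_saturated = bce_clipped(p_saturated, y_saturated)  # stays finite
print("saturated:", naive_saturated, clipped_saturated)
```

Note that this only exercises the loss values, not their gradients, so it does not by itself explain the accuracy gap.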
Below, I condensed my two implementations down to a single script. In this script, I implement the same model in JAX and in Keras. I initialize both with the same weights, and train them using full-batch gradient descent for 10 steps on 1000 datapoints from MNIST, the same data for each model. JAX finishes with 80% training accuracy, while Keras finishes with 90%. Specifically, I get this output:
Initial Keras accuracy: 0.4350000023841858
Initial JAX accuracy: 0.435
Final JAX accuracy: 0.792
Final Keras accuracy: 0.9089999794960022
JAX accuracy (Keras weights): 0.909
Keras accuracy (JAX weights): 0.7919999957084656
And actually, when I vary the conditions a little (using different random initial weights or a different training set), sometimes I get back the 50% JAX accuracy and 90% Keras accuracy.
I swap the weights at the end to verify that the weights obtained from training are indeed the issue, not something to do with the actual computation of the network predictions, or the way I calculate accuracy.
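The weight swap relies on the two libraries storing parameters in different shapes: in my script, Stax gives one tuple per layer (empty for the activation layers), while Keras uses a flat list. A minimal sketch of that conversion, where the helper names stax_to_keras_weights and keras_to_stax_weights are hypothetical and strings stand in for the actual arrays:

```python
def stax_to_keras_weights(stax_params):
    """Flatten [(W, b), (), (W, b), ()] into [W, b, W, b]."""
    flat = []
    for layer_params in stax_params:
        # Parameter-free layers (Relu, Sigmoid) are empty tuples and add nothing.
        flat.extend(layer_params)
    return flat

def keras_to_stax_weights(keras_weights, layout):
    """Regroup a flat list into per-layer tuples.

    layout gives the number of arrays per layer, e.g. (2, 0, 2, 0)
    for Dense, Relu, Dense, Sigmoid.
    """
    grouped, start = [], 0
    for count in layout:
        grouped.append(tuple(keras_weights[start:start + count]))
        start += count
    return grouped

# Placeholder strings instead of real weight arrays, purely for illustration.
params = [("W1", "b1"), (), ("W2", "b2"), ()]
flat = stax_to_keras_weights(params)
restored = keras_to_stax_weights(flat, layout=(2, 0, 2, 0))
print(flat)
print(restored)
```

The script below does the same thing inline, by unpacking layers 0 and 2 and by inserting tuple() for the activation layers.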
The code:
import numpy as np
import jax
from jax import jit, grad
from jax.experimental import stax, optimizers
import jax.numpy as jnp
import keras
import keras.datasets.mnist
def binary_cross_entropy(y_hat, y):
    bce = y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat)
    return jnp.mean(-bce)

def binary_cross_entropy_stable(y_hat, y):
    y_hat = jnp.clip(y_hat, 0.000001, 0.9999999)
    logits = jnp.log(y_hat/(1 - y_hat))
    max_logit = jnp.clip(logits, 0, None)
    bces = logits - logits * y + max_logit + jnp.log(jnp.exp(-max_logit) + jnp.exp(-logits - max_logit))
    return jnp.mean(bces)

def binary_accuracy(y_hat, y):
    return jnp.mean((y_hat >= 1/2) == (y >= 1/2))

########################################
#                                      #
#            Create dataset            #
#                                      #
########################################
input_dimension = 784
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data(path="mnist.npz")
xs = np.concatenate([x_train, x_test])
xs = xs.reshape((70000, 784))
ys = np.concatenate([y_train, y_test])
ys = (ys >= 5).astype(np.float32)
ys = ys.reshape((70000, 1))
train_xs = xs[:1000]
train_ys = ys[:1000]
########################################
#                                      #
#           Create JAX model           #
#                                      #
########################################
jax_initializer, jax_model = stax.serial(
    stax.Dense(1000),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid
)
rng_key = jax.random.PRNGKey(0)
_, initial_jax_weights = jax_initializer(rng_key, (1, input_dimension))
########################################
#                                      #
#          Create Keras model          #
#                                      #
########################################
initial_keras_weights = [*initial_jax_weights[0], *initial_jax_weights[2]]
keras_model = keras.Sequential([
    keras.layers.Dense(1000, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid")
])
keras_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    loss=keras.losses.binary_crossentropy,
    metrics=["accuracy"]
)
keras_model.build(input_shape=(1, input_dimension))
keras_model.set_weights(initial_keras_weights)
if __name__ == "__main__":
    ########################################
    #                                      #
    #       Compare untrained models       #
    #                                      #
    ########################################
    initial_keras_predictions = keras_model.predict(train_xs, verbose=0)
    initial_jax_predictions = jax_model(initial_jax_weights, train_xs)
    _, keras_initial_accuracy = keras_model.evaluate(train_xs, train_ys, verbose=0)
    jax_initial_accuracy = binary_accuracy(jax_model(initial_jax_weights, train_xs), train_ys)
    print("Initial Keras accuracy:", keras_initial_accuracy)
    print("Initial JAX accuracy:", jax_initial_accuracy)

    ########################################
    #                                      #
    #           Train JAX model            #
    #                                      #
    ########################################
    L = jit(binary_cross_entropy_stable)
    gradL = jit(grad(lambda w, x, y: L(jax_model(w, x), y)))
    opt_init, opt_apply, get_params = optimizers.sgd(0.01)
    network_state = opt_init(initial_jax_weights)
    for _ in range(10):
        wT = get_params(network_state)
        gradient = gradL(wT, train_xs, train_ys)
        network_state = opt_apply(
            0,
            gradient,
            network_state
        )
    final_jax_weights = get_params(network_state)
    final_jax_training_predictions = jax_model(final_jax_weights, train_xs)
    final_jax_accuracy = binary_accuracy(final_jax_training_predictions, train_ys)
    print("Final JAX accuracy:", final_jax_accuracy)

    ########################################
    #                                      #
    #          Train Keras model           #
    #                                      #
    ########################################
    for _ in range(10):
        keras_model.fit(
            train_xs,
            train_ys,
            epochs=1,
            batch_size=1000,
            verbose=0
        )
    final_keras_loss, final_keras_accuracy = keras_model.evaluate(train_xs, train_ys, verbose=0)
    print("Final Keras accuracy:", final_keras_accuracy)

    ########################################
    #                                      #
    #             Swap weights             #
    #                                      #
    ########################################
    final_keras_weights = keras_model.get_weights()
    final_keras_weights_in_jax_format = [
        (final_keras_weights[0], final_keras_weights[1]),
        tuple(),
        (final_keras_weights[2], final_keras_weights[3]),
        tuple()
    ]
    jax_accuracy_with_keras_weights = binary_accuracy(
        jax_model(final_keras_weights_in_jax_format, train_xs),
        train_ys
    )
    print("JAX accuracy (Keras weights):", jax_accuracy_with_keras_weights)
    final_jax_weights_in_keras_format = [*final_jax_weights[0], *final_jax_weights[2]]
    keras_model.set_weights(final_jax_weights_in_keras_format)
    _, keras_accuracy_with_jax_weights = keras_model.evaluate(train_xs, train_ys, verbose=0)
    print("Keras accuracy (JAX weights):", keras_accuracy_with_jax_weights)
Try changing the PRNG seed in the line rng_key = jax.random.PRNGKey(0) to a value other than 0
to run the experiment using different initial weights.