I have the following question about JAX. I'll use an example from the official optax docs to illustrate it:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
In this example, the inner function step uses the variable optimizer even though it is not passed as a function argument (presumably because the function is jitted and optax.GradientTransformation is not a supported argument type). However, the same function takes other variables as ordinary arguments (params, opt_state, batch, labels). I understand that JAX functions need to be pure in order to be jitted, but what about input (read-only) variables? Is there any difference between passing a variable through the function arguments and accessing it directly because it is visible in step's enclosing scope? What if such a variable is not constant, but is modified between separate step calls? Is it treated like a static argument when accessed directly? Or is its value simply baked in at trace time, so that later modifications are ignored?
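To clarify what I mean by "static arguments": I'm thinking of jax.jit's static_argnums mechanism, roughly as in this toy sketch (my own example, unrelated to the optax code above):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def scale(x, factor):
  # `factor` is static: each new Python value of `factor` triggers a
  # re-trace/re-compile, and the value is baked into the compiled code.
  return x * factor

print(scale(jnp.arange(3.0), 2))  # compiles for factor=2
print(scale(jnp.arange(3.0), 3))  # re-compiles for factor=3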
To be more specific, let's look at the following example:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # scale every leaf of the update pytree by the closed-over value
    updates = jax.tree_util.tree_map(lambda u: u * extra_learning_rate, updates)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    extra_learning_rate = 0.01  # does this affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels)

  return params
vs
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels, extra_lr):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # scale every leaf of the update pytree by the value passed as an argument
    updates = jax.tree_util.tree_map(lambda u: u * extra_lr, updates)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
    extra_learning_rate = 0.01  # does this now affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)

  return params
From my limited experiments, the two versions behave differently: in the closure (global) version the second step call does not use the new learning rate, and no re-jitting happens either. However, I'd like to know whether there are standard practices/rules I need to be aware of here. I'm writing a library where performance is fundamental, and I don't want to miss out on jit optimizations because I'm doing things wrong.
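For reference, the kind of minimal experiment that led me to that observation looks roughly like this (a toy example of my own, independent of optax and of the loss/data definitions above):

import jax
import jax.numpy as jnp

extra_lr = 0.1

@jax.jit
def step_closure(x):
  # `extra_lr` is read from the enclosing scope; its value at trace time
  # appears to get baked into the compiled computation as a constant.
  return x * extra_lr

@jax.jit
def step_arg(x, lr):
  # `lr` is an ordinary (traced) argument, so new values flow in as data.
  return x * lr

x = jnp.array(2.0)
print(step_closure(x))        # 0.2
extra_lr = 0.01
print(step_closure(x))        # still 0.2 -- the cached trace keeps the old value
print(step_arg(x, extra_lr))  # 0.02 -- the new value is passed in as data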