For example, in jax.experimental.stax there is a Dense layer implemented like this:
```python
import jax.numpy as jnp
from jax import random
from jax.nn.initializers import glorot_normal, normal

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
    """Layer constructor function for a dense (fully-connected) layer."""
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
        return output_shape, (W, b)
    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return jnp.dot(inputs, W) + b
    return init_fun, apply_fun
```
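For reference, a layer built this way is used roughly as follows (a sketch; the layer width and input shape here are made up):

```python
init_fun, apply_fun = Dense(16)
rng = random.PRNGKey(0)
output_shape, params = init_fun(rng, (-1, 32))  # params == (W, b)
out = apply_fun(params, jnp.ones((4, 32)))      # shape (4, 16)
```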
If we allow the bias to be None, for example, or allow params to have length 1, there are implications for how grad works.
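As far as I can tell, the None route mostly just works, because JAX treats None as an empty pytree node, so grad simply returns None in the matching slot. A minimal sketch (my own apply_fun, not the stax one):

```python
import jax
import jax.numpy as jnp

def apply_fun(params, inputs):
    # b is None when the layer has no bias; the check is on the pytree
    # structure, which is static, so it also works under jit.
    W, b = params
    out = jnp.dot(inputs, W)
    return out if b is None else out + b

def loss(params, inputs):
    return jnp.sum(apply_fun(params, inputs) ** 2)

W = jnp.ones((3, 2))
grads = jax.grad(loss)((W, None), jnp.ones((4, 3)))  # -> (dW, None)
```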
What is the pattern one should aim for here? jax.jit has a static_argnums argument that I suppose could be used with some has_bias parameter, but the bookkeeping for that is involved, and I am sure there must be examples of this somewhere.
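One pattern I can imagine (a sketch only; the use_bias argument is my invention, not part of stax) is to make the flag a constructor argument. It is then an ordinary Python value captured by the closures, fixed at trace time, so jit needs no static_argnums at all:

```python
import jax.numpy as jnp
from jax import random
from jax.nn.initializers import glorot_normal, normal

def Dense(out_dim, W_init=glorot_normal(), b_init=normal(), use_bias=True):
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W = W_init(k1, (input_shape[-1], out_dim))
        # Hypothetical convention: params has length 1 when bias is off.
        params = (W, b_init(k2, (out_dim,))) if use_bias else (W,)
        return output_shape, params
    def apply_fun(params, inputs, **kwargs):
        # use_bias is a plain Python bool, so this branch is resolved
        # while jit traces apply_fun, not at runtime.
        if use_bias:
            W, b = params
            return jnp.dot(inputs, W) + b
        (W,) = params
        return jnp.dot(inputs, W)
    return init_fun, apply_fun
```

Since the branch on use_bias happens in Python during tracing, each configuration compiles to a straight-line computation with no runtime flag, and grad just sees params as a plain pytree of whichever length init_fun produced.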