
For example, in jax.experimental.stax there is a Dense layer implemented like this:

import jax.numpy as jnp
from jax import random
from jax.nn.initializers import glorot_normal, normal

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""
  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(k2, (out_dim,))
    return output_shape, (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return jnp.dot(inputs, W) + b
  return init_fun, apply_fun
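
For context, stax layers are (init_fun, apply_fun) pairs that get composed, e.g. via stax.serial; a minimal usage sketch with made-up shapes:

import jax.numpy as jnp
from jax import random
from jax.experimental import stax

# serial composes layers and returns another (init_fun, apply_fun) pair.
init_fun, apply_fun = stax.serial(stax.Dense(128), stax.Relu, stax.Dense(10))

rng = random.PRNGKey(0)
output_shape, params = init_fun(rng, (-1, 784))  # -1 stands in for the batch dim
preds = apply_fun(params, jnp.ones((32, 784)))   # shape (32, 10)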

If we allow the bias to be None, for example, or allow params to have length 1, there are implications for how grad works.

What is the pattern one should aim for here? jax.jit has static_argnums, which I suppose could be used with some has_bias param, but the book-keeping involved seems substantial, and I am sure there must be examples of this somewhere.
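
For reference, a rough sketch of what that static_argnums route might look like (the has_bias flag is hypothetical, not something stax provides); because the flag is a plain Python bool marked static, jit resolves the branch at trace time:

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=2)
def apply_fun(params, inputs, has_bias):
    # has_bias is static, so this Python-level branch is resolved when tracing.
    if has_bias:
        W, b = params
        return jnp.dot(inputs, W) + b
    W, = params
    return jnp.dot(inputs, W)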

1 Answer


Wouldn't this work?

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
  """Layer constructor function for a dense (fully-connected) layer."""

  def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W = W_init(k1, (input_shape[-1], out_dim))
    if b_init:
        b = b_init(k2, (out_dim,))
        return output_shape, (W, b)
    # No bias: return a one-element tuple so apply_fun can distinguish
    # the two cases by the length of the params pytree.
    return output_shape, (W,)

  def apply_fun(params, inputs, **kwargs):
    if len(params) == 1:
        W, = params
        return jnp.dot(inputs, W)
    else:
        W, b = params
        return jnp.dot(inputs, W) + b

  return init_fun, apply_fun
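
For what it's worth, a minimal usage sketch of the version above (made-up shapes and a made-up squared-error loss): grad simply mirrors whatever pytree structure init_fun returned, so both variants differentiate without any extra bookkeeping.

import jax
import jax.numpy as jnp
from jax import random

# Hypothetical data, purely for illustration.
x = jnp.ones((4, 5))
y = jnp.zeros((4, 3))

init_fun, apply_fun = Dense(3)             # params will be (W, b)
init_nb, apply_nb = Dense(3, b_init=None)  # params will be (W,)

rng = random.PRNGKey(0)
_, params = init_fun(rng, (-1, 5))
_, params_nb = init_nb(rng, (-1, 5))

def loss(p):
    return jnp.mean((apply_fun(p, x) - y) ** 2)

def loss_nb(p):
    return jnp.mean((apply_nb(p, x) - y) ** 2)

grads = jax.grad(loss)(params)           # a (dW, db) tuple
grads_nb = jax.grad(loss_nb)(params_nb)  # a (dW,) tuple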