
I am looking at DeepMind's implementation of a transformer using the Haiku neural network library.

I'm confused by their forward function:

from typing import Mapping

import haiku as hk
import jax.numpy as jnp

# (`model` below is the example's model.py, which defines the Transformer.)


def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
  """Create the model's forward pass."""

  def forward_fn(data: Mapping[str, jnp.ndarray],
                 is_training: bool = True) -> jnp.ndarray:
    """Forward pass."""
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings

    # Run the transformer over the inputs.
    transformer = model.Transformer(
        num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
    output_embeddings = transformer(input_embeddings, input_mask, is_training)

    # Reverse the embeddings (untied).
    return hk.Linear(vocab_size)(output_embeddings)

  return forward_fn

During gradient descent, this function is called on every training step to compute the loss. That much is expected.
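For context, here is roughly how I understand the function is wired up and called, a simplified sketch of my reading of train.py (the hyperparameters and the loss are placeholders, and it assumes the quoted build_forward_fn plus the example's model module are importable):

import haiku as hk
import jax
import jax.numpy as jnp

# Transform the forward function (hyperparameters here are made up).
forward_fn = hk.transform(build_forward_fn(vocab_size=100, d_model=32,
                                           num_heads=4, num_layers=2,
                                           dropout_rate=0.1))

rng = jax.random.PRNGKey(42)
data = {'obs': jnp.ones((8, 16), dtype=jnp.int32)}  # dummy token batch

# init runs forward_fn once and collects the parameters it creates.
params = forward_fn.init(rng, data)

def loss_fn(params, rng, data):
  # apply re-executes forward_fn's body, so hk.Embed(...) and the
  # Transformer are "constructed" again on every call.
  logits = forward_fn.apply(params, rng, data)
  return jnp.mean(logits)  # stand-in for the real language-model loss

# Every gradient step calls loss_fn, and therefore forward_fn, again.
grads = jax.grad(loss_fn)(params, rng, data)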

The confusing part is that each layer is constructed from scratch every time the function is called (for example, hk.Embed(...)). Yet somehow these layers maintain a consistent identity across calls.

How is Haiku keeping track of these layers? I'm asking because I'd like to get a mutable reference to these layer objects so I can, for example, print out their weights.
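To make that concrete: continuing the sketch above, the closest I have found is inspecting the parameter arrays returned by init, which (as far as I can tell) Haiku keys by module name:

# params is a nested mapping keyed by module name, so I can at least dump
# the shape (or value) of every weight array...
print(jax.tree_util.tree_map(lambda p: p.shape, params))

# ...but that only gives me arrays by name, not a reference to the hk.Embed
# object constructed inside forward_fn, which is what I am really after.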

Foobar
  • Have you seen how many times this forward_fn function is called in the code? It is exactly once, because it builds a symbolic representation as most DL frameworks do. – Dr. Snoopy Jun 29 '22 at 06:19
  • It is called multiple times. See: https://github.com/deepmind/dm-haiku/blob/aefcfc40a1f6c551e82fe96ffbe4871ef39c1891/examples/transformer/train.py#L100 – Foobar Jun 29 '22 at 06:20
  • No, you are not tracing the function calls correctly; forward_fn is defined multiple times in the code (as functions and objects). – Dr. Snoopy Jun 29 '22 at 06:23
  • I'm quite sure this is not true? `lm_loss_fn` is transformed by a `functools.partial` and then called multiple times. – Foobar Jun 29 '22 at 18:59
  • @Dr.Snoopy Also, see this: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html – Foobar Jun 29 '22 at 18:59

0 Answers