I am looking at DeepMind's implementation of a transformer using the Haiku neural network library.
I'm confused by their forward function:
from typing import Mapping

import haiku as hk
import jax.numpy as jnp

import model  # the Transformer core defined in the same DeepMind example


def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
  """Create the model's forward pass."""

  def forward_fn(data: Mapping[str, jnp.ndarray],
                 is_training: bool = True) -> jnp.ndarray:
    """Forward pass."""
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)
    seq_length = tokens.shape[1]

    # Embed the input tokens and positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings

    # Run the transformer over the inputs.
    transformer = model.Transformer(
        num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
    output_embeddings = transformer(input_embeddings, input_mask, is_training)

    # Reverse the embeddings (untied).
    return hk.Linear(vocab_size)(output_embeddings)

  return forward_fn
During gradient descent, this function is called on every step to compute the loss. That's all expected.
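For context, here is roughly how I'm calling it (a minimal sketch of the standard hk.transform workflow; the hyperparameters, batch shapes, and loss below are placeholders I made up, not the real training setup):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    forward_fn = build_forward_fn(
        vocab_size=1000, d_model=128, num_heads=4, num_layers=2,
        dropout_rate=0.1)
    # hk.transform turns the impure forward_fn into a pair of pure functions.
    forward = hk.transform(forward_fn)

    rng = jax.random.PRNGKey(42)
    data = {'obs': jnp.ones((8, 16), dtype=jnp.int32)}  # dummy token batch

    params = forward.init(rng, data)           # allocates all parameters
    logits = forward.apply(params, rng, data)  # runs the forward pass

    def loss_fn(params, rng, data):
      logits = forward.apply(params, rng, data)
      return jnp.mean(logits ** 2)  # placeholder loss, not the real objective

    grads = jax.grad(loss_fn)(params, rng, data)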
The confusing thing is that each of the layers is constructed from scratch every time the function is called, e.g. hk.Embed(...). Yet somehow these layers maintain a consistent identity: the same parameters are reused across calls.
How is Haiku keeping track of these layers? I'm asking because I'd like to get a mutable reference to these layer objects so I can, for example, print out the weights.
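I do know that hk.transform hands back the parameter values as a nested dict, so I can already print the raw arrays, something like this (the exact key names such as 'embed'/'embeddings' are my guess at Haiku's automatic module naming):

    params = forward.init(rng, data)

    # params is a nested mapping: module name -> parameter name -> jnp.ndarray.
    for module_name, module_params in params.items():
      for param_name, value in module_params.items():
        print(module_name, param_name, value.shape)

    # For example, the hk.Embed weights should live under something like:
    # params['embed']['embeddings']

But that only gives me the arrays. What I'm after is a reference to the layer objects themselves (the hk.Embed, the hk.Linear) that Haiku apparently reconstructs on every call.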