I'm making a transformer using tensorflow.keras and having issues understanding how the attention_mask works for a MultiHeadAttention layer.
My input is 3-dimensional data. For example, let's assume my whole dataset has 10 elements, each one with length no more than 4:
# whole data
[
    # first item
    [
        [     1,      2,      3],
        [     1,      2,      3],
        [np.nan, np.nan, np.nan],
        [np.nan, np.nan, np.nan],
    ],
    # second item
    [
        [1, 2, 3],
        [5, 8, 2],
        [3, 7, 8],
        [4, 6, 2],
    ],
    ...  # 8 more items
]
So, my mask looks like:
# assume this is a numpy array
mask = [
    [
        [1, 1, 1],
        [1, 1, 1],
        [0, 0, 0],
        [0, 0, 0],
    ],
    [
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
    ],
    ...
]
So the shape of the mask so far is [10, 4, 3]. Let's say I use batch_size = 5. Now, according to the documentation, the attention_mask shape should be [B, T, S] (batch_size, query_size, key_size). In the example case, should it be [5, 4, 4]?
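For reference, here is a minimal NumPy sketch of how a [N, T, S] mask could be derived from NaN padding like the one above. It is only an illustration under assumptions: the two-item `data` array is a hypothetical stand-in for the dataset, and the choice to mask both padded queries and padded keys is one possible convention, not necessarily what the layer requires.

```python
import numpy as np

# Hypothetical toy data: 2 items, up to 4 timesteps, 3 features.
# Rows filled with NaN are padding.
data = np.array(
    [
        [[1, 2, 3], [1, 2, 3], [np.nan] * 3, [np.nan] * 3],
        [[1, 2, 3], [5, 8, 2], [3, 7, 8], [4, 6, 2]],
    ],
    dtype=float,
)

# A timestep is valid when none of its features is NaN -> shape [N, T].
valid = ~np.isnan(data).any(axis=-1)

# Pairwise mask: query t may attend to key s only if both are valid -> [N, T, S].
attention_mask = valid[:, :, None] & valid[:, None, :]

print(valid.shape)           # (2, 4)
print(attention_mask.shape)  # (2, 4, 4)
```

Note that the feature axis (size 3) disappears here: the mask is about which timesteps can see which timesteps, not about individual features.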
Question
If the mask is calculated only once, what 5 items should I give as a mask? This sounds counterintuitive to me. How should I build the mask?
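One possible reading of the [B, T, S] requirement is that the mask is computed once for the whole dataset and then sliced with the same indices as each batch of data. The sketch below uses NumPy only; the `lengths` array is a hypothetical set of sequence lengths invented for illustration. In Keras this would presumably mean feeding the mask alongside the data (for example as a second model input) so that both are batched together, but that is an assumption, not something confirmed by the documentation quoted here.

```python
import numpy as np

n_items, seq_len, batch_size = 10, 4, 5

# Hypothetical number of real (non-padded) timesteps for each of the 10 items.
lengths = np.array([2, 4, 3, 1, 4, 2, 3, 4, 1, 2])

# Validity per timestep for the whole dataset -> shape [10, 4].
valid = np.arange(seq_len)[None, :] < lengths[:, None]

# Pairwise mask for the whole dataset -> shape [10, 4, 4].
full_mask = valid[:, :, None] & valid[:, None, :]

# Each batch takes the mask rows with the same indices as its data rows.
batch_masks = [
    full_mask[start:start + batch_size]
    for start in range(0, n_items, batch_size)
]

print([m.shape for m in batch_masks])  # [(5, 4, 4), (5, 4, 4)]
```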
According to this answer, the number of heads should also be taken into account, so they also do:
mask = mask[:, tf.newaxis, tf.newaxis, :]
What I've tested
The only time I manage to run the transformer successfully using the attention_mask is when I do:
mask = np.ones((batch_size, data.shape[1], data.shape[2]))
mask = mask[:, tf.newaxis, tf.newaxis, :]
Obviously that mask makes no sense, because it is all ones, but it was just to test if it had the correct shape.
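To make the shapes concrete, here is a small NumPy check of what that indexing produces (`np.newaxis` and `tf.newaxis` are both `None`, so the indexing behaves the same way). The 2-D `key_mask` is an assumption about what the linked answer starts from; it is not taken from that answer.

```python
import numpy as np

batch_size, seq_len, n_features = 5, 4, 3

# The all-ones test mask from above: [batch, timesteps, features].
test_mask = np.ones((batch_size, seq_len, n_features))
expanded = test_mask[:, np.newaxis, np.newaxis, :]
print(expanded.shape)  # (5, 1, 1, 4, 3)

# Assumed starting point of the linked answer: a 2-D [batch, key] mask.
key_mask = np.ones((batch_size, seq_len))
expanded_key_mask = key_mask[:, np.newaxis, np.newaxis, :]
print(expanded_key_mask.shape)  # (5, 1, 1, 4)
```

So the test mask ends up 5-dimensional, while the same indexing applied to a 2-D mask gives [B, 1, 1, S], which can broadcast over the heads and query axes.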
The model
I'm using practically the same code as the keras example transformer for time series classification:
from tensorflow import keras
from tensorflow.keras import layers


def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0.0, mask=None):
    # Normalization and Attention
    x = layers.LayerNormalization(epsilon=1e-6)(inputs)
    x = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(x, x, attention_mask=mask)
    x = layers.Dropout(dropout)(x)
    res = x + inputs

    # Feed Forward Part
    x = layers.LayerNormalization(epsilon=1e-6)(res)
    x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(x)
    x = layers.Dropout(dropout)(x)
    x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
    return x + res

def build_model(
    n_classes,
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0.0,
    mlp_dropout=0.0,
    input_mask=None,
) -> keras.Model:
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout, input_mask)

    x = layers.GlobalAveragePooling2D(data_format="channels_first")(x)
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)