I am trying to use a transformer to analyze spatio-temporal data. My training array has dimensions "batch size x sequence length x spatial samples x embedding dimension." To keep the transformer from cheating during training, I want to build an attention mask that prevents future timesteps from attending to timesteps in the past. The following snippet reproduces the error on its own:
import numpy as np
import tensorflow as tf
from tensorflow import keras
class TemporalMaskedTransformerBlock(keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TemporalMaskedTransformerBlock, self).__init__()
        self.att1 = keras.layers.MultiHeadAttention(num_heads=num_heads,
                                                    key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [keras.layers.Dense(ff_dim, activation="gelu"),
             keras.layers.Dense(embed_dim),]
        )
        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)

    def causal_attention_mask(self, batch_size, seq_len, embed_dim, dtype):
        mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        print(" -- mask =", np.shape(mask))
        mask = 1 - mask  # Invert the mask to keep future information
        print(" -- mask =", np.shape(mask))
        mask = tf.expand_dims(mask, 0)
        print(" -- mask =", np.shape(mask))
        mask = tf.expand_dims(mask, -1)
        print(" -- mask =", np.shape(mask))
        mask = tf.tile(mask, [batch_size, 1, 1, embed_dim])
        print(" -- mask =", np.shape(mask))
        return mask

    def call(self, inputs, training):
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        spat_samples = input_shape[2]
        embed_dim = input_shape[3]
        mask = self.causal_attention_mask(batch_size, seq_len, embed_dim, inputs.dtype)
        print("inputs =", np.shape(inputs), ", mask =", np.shape(mask))
        attn_output = self.att1(inputs, inputs, attention_mask=mask)
        print("attn =", np.shape(attn_output))
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
batch_size = 100
seq_len = 20
spatial_samples = 100
embed_dim = 10
ff_dim = 100
input_data = tf.constant(np.random.rand(batch_size, seq_len, spatial_samples, embed_dim))
input_layer = keras.layers.Input(shape=(input_data.shape[1:]), name='Input')
''' Put data into the transformer '''
transformer = TemporalMaskedTransformerBlock(embed_dim=input_layer.shape[-1],
                                             num_heads=4,
                                             ff_dim=ff_dim)(input_layer)
model = keras.Model(input_layer, transformer, name="Test")
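For reference, this is the pattern I am starting from in causal_attention_mask: tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) keeps the lower triangle (each timestep can see itself and earlier timesteps), which I then invert and tile. A minimal standalone sketch, with seq_len=4 purely for illustration (not part of the failing code):

import tensorflow as tf

# Lower-triangular causal pattern: entry [i, j] is 1 when j <= i.
causal = tf.linalg.band_part(tf.ones((4, 4)), -1, 0)
print(causal.numpy())
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]

# Inverting it, as in causal_attention_mask above, keeps only the strictly
# upper-triangular (future) entries.
print((1 - causal).numpy())
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]

So the per-timestep pattern looks right to me; the problem seems to be getting it into a shape that MultiHeadAttention will accept for 4-D inputs.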
This is the error produced when I run the reproduction snippet above:
-- mask = (20, 20)
-- mask = (20, 20)
-- mask = (1, 20, 20)
-- mask = (1, 20, 20, 1)
-- mask = (None, 20, 20, 10)
inputs = (None, 20, 100, 10) , mask = (None, 20, 20, 10)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In [18], line 2
1 ''' Put embedded data into the transformer '''
----> 2 transformer = TemporalMaskedTransformerBlock(embed_dim=input_layer.shape[-1],
3 num_heads=4,
4 ff_dim=ff_dim)(input_layer)
6 model = keras.Model(inputs_layer, transformer, name="Test")
7 keras.utils.plot_model(model, show_shapes=True)
File ~/tensorflow-test/env/lib/python3.10/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
File /var/folders/pr/x8x6s_v91y98r86gkdvp9pz40000gn/T/__autograph_generated_filetqbf4k08.py:17, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs, training)
15 mask = ag__.converted_call(ag__.ld(self).causal_attention_mask, (ag__.ld(batch_size), ag__.ld(seq_len), ag__.ld(embed_dim), ag__.ld(inputs).dtype), None, fscope)
16 ag__.ld(print)('inputs =', ag__.converted_call(ag__.ld(np).shape, (ag__.ld(inputs),), None, fscope), ', mask =', ag__.converted_call(ag__.ld(np).shape, (ag__.ld(mask),), None, fscope))
---> 17 attn_output = ag__.converted_call(ag__.ld(self).att1, (ag__.ld(inputs), ag__.ld(inputs)), dict(attention_mask=ag__.ld(mask)), fscope)
18 ag__.ld(print)('attn =', ag__.converted_call(ag__.ld(np).shape, (ag__.ld(attn_output),), None, fscope))
19 attn_output = ag__.converted_call(ag__.ld(self).dropout1, (ag__.ld(attn_output),), dict(training=ag__.ld(training)), fscope)
ValueError: Exception encountered when calling layer "temporal_masked_transformer_block_2" (type TemporalMaskedTransformerBlock).
in user code:
File "/var/folders/pr/x8x6s_v91y98r86gkdvp9pz40000gn/T/ipykernel_21498/2208681912.py", line 44, in call *
attn_output = self.att1(inputs, inputs, attention_mask=mask)
File "/Users/joshuamiller/tensorflow-test/env/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler **
raise e.with_traceback(filtered_tb) from None
ValueError: Exception encountered when calling layer "softmax" " f"(type Softmax).
Dimensions must be equal, but are 100 and 20 for '{{node temporal_masked_transformer_block_2/multi_head_attention_2/softmax/add}} = AddV2[T=DT_FLOAT](temporal_masked_transformer_block_2/multi_head_attention_2/einsum/Einsum, temporal_masked_transformer_block_2/multi_head_attention_2/softmax/mul)' with input shapes: [?,4,20,100,20,100], [1,1,?,20,20,10].
Call arguments received by layer "softmax" " f"(type Softmax):
• inputs=tf.Tensor(shape=(None, 4, 20, 100, 20, 100), dtype=float32)
• mask=tf.Tensor(shape=(1, 1, None, 20, 20, 10), dtype=float32)
Call arguments received by layer "temporal_masked_transformer_block_2" (type TemporalMaskedTransformerBlock):
• inputs=tf.Tensor(shape=(None, 20, 100, 10), dtype=float32)
• training=None
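For what it's worth, a small separate check (the shapes and layer arguments below are just illustrative, not the original failing model) suggests that with 4-D inputs MultiHeadAttention attends over both the temporal and spatial axes by default, which matches the (None, 4, 20, 100, 20, 100) score shape in the traceback and would explain why my (batch, 20, 20, 10) mask does not broadcast against it:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Toy input shaped (batch, time, space, embed), small enough to inspect.
x = tf.constant(np.random.rand(2, 20, 100, 10), dtype=tf.float32)
mha = keras.layers.MultiHeadAttention(num_heads=4, key_dim=10)
# return_attention_scores=True exposes the raw attention score shape.
out, scores = mha(x, x, return_attention_scores=True)
print("out    =", out.shape)     # (2, 20, 100, 10)
print("scores =", scores.shape)  # (2, 4, 20, 100, 20, 100): batch, heads, query (time, space), key (time, space)

So I suspect I need to build the mask with a different rank/shape, but I am not sure what MultiHeadAttention actually expects here for 4-D inputs, or how to restrict the masking to the temporal axis only.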