0

I am trying to use a transformer to analyze some spatio-temporal data. I have an array of training data with dimensions "batch size x sequence length x spatial samples x embedding dimension." In order to prevent the transformer from cheating while training, I want to make an attention mask that will prevent data from future timesteps from attending to timesteps in the past. This code snippet on its own reproduces the error.

import numpy as np
import tensorflow as tf
from tensorflow import keras

class TemporalMaskedTransformerBlock(keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TemporalMaskedTransformerBlock, self).__init__()
        self.att1 = keras.layers.MultiHeadAttention(num_heads=num_heads,
                                                    key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [keras.layers.Dense(ff_dim, activation="gelu"),
             keras.layers.Dense(embed_dim),]
        )
        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)
        
    def causal_attention_mask(self, batch_size, seq_len, embed_dim, dtype):
        mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        print(" -- mask =", np.shape(mask))
        mask = 1 - mask  # Invert the mask to keep future information
        print(" -- mask =", np.shape(mask))
        mask = tf.expand_dims(mask, 0)
        print(" -- mask =", np.shape(mask))
        mask = tf.expand_dims(mask, -1)
        print(" -- mask =", np.shape(mask))
        mask = tf.tile(mask, [batch_size, 1, 1, embed_dim])
        print(" -- mask =", np.shape(mask))
        return mask

    def call(self, inputs, training):
        input_shape = tf.shape(inputs)
        
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        spat_samples = input_shape[2]
        embed_dim = input_shape[3]
        
        mask = self.causal_attention_mask(batch_size, seq_len, embed_dim, inputs.dtype)
        
        print("inputs =", np.shape(inputs), ", mask =", np.shape(mask))
        
        attn_output = self.att1(inputs, inputs, attention_mask=mask)
        
        print("attn =", np.shape(attn_output))
        
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

batch_size = 100
seq_len = 20
spatial_samples = 100
embed_dim = 10

ff_dim = 100


input_data = tf.constant(np.random.rand(batch_size, seq_len, spatial_samples, embed_dim))

input_layer = keras.layers.Input(shape=(input_data.shape[1:]), name='Input')

''' Put data into the transformer '''
transformer = TemporalMaskedTransformerBlock(embed_dim=input_layer.shape[-1],
                                             num_heads=4,
                                             ff_dim=ff_dim)(input_layer)

model = keras.Model(inputs_layer, transformer, name="Test")

This is the error produced when I run it:

 -- mask = (20, 20)
 -- mask = (20, 20)
 -- mask = (1, 20, 20)
 -- mask = (1, 20, 20, 1)
 -- mask = (None, 20, 20, 10)
inputs = (None, 20, 100, 10) , mask = (None, 20, 20, 10)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [18], line 2
      1 ''' Put embedded data into the transformer '''
----> 2 transformer = TemporalMaskedTransformerBlock(embed_dim=input_layer.shape[-1],
      3                                              num_heads=4,
      4                                              ff_dim=ff_dim)(input_layer)
      6 model = keras.Model(inputs_layer, transformer, name="Test")
      7 keras.utils.plot_model(model, show_shapes=True)

File ~/tensorflow-test/env/lib/python3.10/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /var/folders/pr/x8x6s_v91y98r86gkdvp9pz40000gn/T/__autograph_generated_filetqbf4k08.py:17, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs, training)
     15 mask = ag__.converted_call(ag__.ld(self).causal_attention_mask, (ag__.ld(batch_size), ag__.ld(seq_len), ag__.ld(embed_dim), ag__.ld(inputs).dtype), None, fscope)
     16 ag__.ld(print)('inputs =', ag__.converted_call(ag__.ld(np).shape, (ag__.ld(inputs),), None, fscope), ', mask =', ag__.converted_call(ag__.ld(np).shape, (ag__.ld(mask),), None, fscope))
---> 17 attn_output = ag__.converted_call(ag__.ld(self).att1, (ag__.ld(inputs), ag__.ld(inputs)), dict(attention_mask=ag__.ld(mask)), fscope)
     18 ag__.ld(print)('attn =', ag__.converted_call(ag__.ld(np).shape, (ag__.ld(attn_output),), None, fscope))
     19 attn_output = ag__.converted_call(ag__.ld(self).dropout1, (ag__.ld(attn_output),), dict(training=ag__.ld(training)), fscope)

ValueError: Exception encountered when calling layer "temporal_masked_transformer_block_2" (type TemporalMaskedTransformerBlock).

in user code:

    File "/var/folders/pr/x8x6s_v91y98r86gkdvp9pz40000gn/T/ipykernel_21498/2208681912.py", line 44, in call  *
        attn_output = self.att1(inputs, inputs, attention_mask=mask)
    File "/Users/joshuamiller/tensorflow-test/env/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None

    ValueError: Exception encountered when calling layer "softmax" "                 f"(type Softmax).
    
    Dimensions must be equal, but are 100 and 20 for '{{node temporal_masked_transformer_block_2/multi_head_attention_2/softmax/add}} = AddV2[T=DT_FLOAT](temporal_masked_transformer_block_2/multi_head_attention_2/einsum/Einsum, temporal_masked_transformer_block_2/multi_head_attention_2/softmax/mul)' with input shapes: [?,4,20,100,20,100], [1,1,?,20,20,10].
    
    Call arguments received by layer "softmax" "                 f"(type Softmax):
      • inputs=tf.Tensor(shape=(None, 4, 20, 100, 20, 100), dtype=float32)
      • mask=tf.Tensor(shape=(1, 1, None, 20, 20, 10), dtype=float32)


Call arguments received by layer "temporal_masked_transformer_block_2" (type TemporalMaskedTransformerBlock):
  • inputs=tf.Tensor(shape=(None, 20, 100, 10), dtype=float32)
  • training=None

0 Answers0