
I am implementing a modified vision transformer based on a GitHub implementation. The author has also published a YouTube video explaining the implementation. However, this implementation has no provision for incorporating src_key_padding_mask (the built-in transformer encoder accepts this as a parameter). I know that I have to perform some mathematical operation using this mask in the forward method of the Attention module. The mask contains True where there is a padding token and False elsewhere.

If I simply add dp = dp @ mask right after the dot product of the query and key, will it serve the same purpose as src_key_padding_mask in the built-in version?

dp = (q @ k_t) * self.scale # (n_samples, n_heads, n_patches + 1, n_patches + 1)
dp = dp @ mask
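For reference, a minimal sketch of how a boolean key-padding mask is commonly applied in scaled dot-product attention (and how PyTorch's built-in key_padding_mask treats True positions as "ignore"): rather than matrix-multiplying the mask into the scores, the mask is usually broadcast over the heads and query positions and the masked scores are set to -inf before the softmax. The shapes and variable names below are assumptions based on the comment in the question, not the author's actual code.

import torch

# Assumed shapes, following the comment in the question:
# dp:   (n_samples, n_heads, n_patches + 1, n_patches + 1) -- attention scores
# mask: (n_samples, n_patches + 1), True where a position is padding
n_samples, n_heads, seq_len = 2, 4, 5  # illustrative values only
dp = torch.randn(n_samples, n_heads, seq_len, seq_len)
mask = torch.zeros(n_samples, seq_len, dtype=torch.bool)
mask[:, -2:] = True  # pretend the last two tokens are padding

# Broadcast the key-padding mask to (n_samples, 1, 1, seq_len)
# so it masks out entire key columns for every head and query.
expanded = mask[:, None, None, :]

# Set masked scores to -inf so they get ~zero weight after softmax.
dp = dp.masked_fill(expanded, float("-inf"))
attn = dp.softmax(dim=-1)  # padded keys now receive (almost) no attention

This mirrors the usual additive-masking idea: multiplying the scores by the mask would change their magnitudes but would not remove the padded keys from the softmax normalization.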

0 Answers