I am trying to plot attention maps for a ViT. I know that I can do something like
h_attn = model.blocks[-1].attn.register_forward_hook(get_activations('attention'))
to register a forward hook that captures the output of some nn.Module in my model.
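For context, get_activations is just a small hook factory, roughly like the following (the global activations dict and the exact setup are simply how I happen to have it locally):

activations = {}

def get_activations(name):
    # standard forward-hook signature: (module, inputs, output)
    def hook(module, inputs, output):
        # store the hooked module's return value under `name`
        activations[name] = output.detach()
    return hook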
The ViT's attention layer has the following forward structure:
def forward(self, x):
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
Can I somehow attach the hook so that I get the intermediate attn value (the attention matrix after the softmax) rather than the return value of forward, e.g. by using some kind of dummy module as in the sketch below?
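What I mean by "dummy module" is roughly the following minimal sketch, assuming the forward above is timm's vision_transformer.Attention and that I am allowed to swap the attention module in the model; AttentionWithProbe and attn_probe are names I made up for illustration:

import torch.nn as nn
from timm.models.vision_transformer import Attention  # assuming the forward above is timm's implementation

class AttentionWithProbe(Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_probe = nn.Identity()  # dummy module: passes attn through unchanged

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_probe(attn.softmax(dim=-1))  # hook target: output is the attention map
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

With this I could register the hook on model.blocks[-1].attn.attn_probe instead, but duplicating the whole forward just to expose one tensor feels wrong, hence the question.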