
How can I generate attention maps for 3D grayscale MRI data after training a vision transformer for a classification problem?

Each volume has shape (120, 120, 120) and the model is a 3D ViT. My code looks roughly like this:

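import nibabel as nib
import torch

img = nib.load(...).get_fdata()      # path elided; returns a (120, 120, 120) float64 array
img = torch.from_numpy(img).float()  # convert to a float32 tensor
img = img[None, None]                # (1, 1, 120, 120, 120), assuming a (B, C, D, H, W) input
model = torch.load(...)              # checkpoint path elided
model.eval()

with torch.no_grad():
    output, attn = model(img)        # this model returns logits and attention weights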

Because the model has 6 transformer layers with 12 heads each, the attn tensor I get has shape

(6, 12, 65, 65)

I don't know how to map this back onto the original 3D grayscale volume. The examples I found online only deal with 2D images from ImageNet.

For example:

https://github.com/tczhangzhi/VisionTransformer-Pytorch

https://github.com/jacobgil/vit-explain

Can anyone help me with this?

1 Answer


I would guess your ViT splits each volume into a 4x4x4 grid of patches (64 patch tokens, i.e. 30x30x30 voxels each for a 120x120x120 volume) and adds a single cls token, giving 65 tokens per volume.
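To see where the 65 comes from and why `reshape(4, 4, 4)` recovers the spatial layout, here is a small sketch. It assumes (not confirmed by the question) that the model cuts the volume into non-overlapping 30x30x30 patches and flattens the resulting 4x4x4 patch grid in row-major (z, y, x) order, as a `Conv3d` patch embedding followed by `flatten(2)` would. `avg_pool3d` stands in for the patch embedding purely to make the token ordering visible.

```python
import torch
import torch.nn.functional as F

PATCH = 30    # assumed patch size: 120 / 4 = 30 voxels per side
SIDE = 120    # volume side length from the question

grid = SIDE // PATCH               # 4 patches per axis
num_patches = grid ** 3            # 64 patch tokens
num_tokens = num_patches + 1       # plus one cls token -> 65

# Label every patch with its flat index so we can track where it ends up.
labels = torch.arange(num_patches, dtype=torch.float32).reshape(grid, grid, grid)
vol = labels
for dim in range(3):               # blow each label up to a 30x30x30 block
    vol = vol.repeat_interleave(PATCH, dim=dim)   # ends as (120, 120, 120)

# Stand-in for a patch embedding: one value per patch, then flatten the grid
# into a token sequence the same way Conv3d(...).flatten(2) would.
pooled = F.avg_pool3d(vol[None, None], kernel_size=PATCH)   # (1, 1, 4, 4, 4)
seq = pooled.flatten(2)[0, 0]                                # (64,)

# Token i sits at grid position (i // 16, (i // 4) % 4, i % 4), so
# reshape(4, 4, 4) puts every token back in its spatial place.
print(num_tokens)                                             # 65
print(torch.allclose(seq.reshape(grid, grid, grid), labels))  # True
```

If your patch embedding orders the axes differently, the reshaped grid has to be permuted to match before it is compared with the volume.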

If you want to see how the cls token attends to all other 64 tokens for a specific layer and a specific head, you can:

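import matplotlib.pyplot as plt

layer = 4  # 0-based index: the 5th layer
head = 7   # 0-based index: the 8th head
# Row 0 holds the cls token's attention; drop column 0 (cls -> cls) and
# fold the 64 patch tokens back into their 4x4x4 grid.
cls_attn = attn[layer, head, 0, 1:].reshape(4, 4, 4).detach().cpu().numpy()
fig, ax = plt.subplots(2, 2)  # one panel per slice along the first grid axis
for z in range(cls_attn.shape[0]):
    ax.flat[z].matshow(cls_attn[z, ...])
    ax.flat[z].set_title(f"z = {z}")
plt.show()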
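To put the map on top of the MRI itself, one option is to upsample the 4x4x4 cls attention to the volume size with trilinear interpolation and overlay it slice by slice. In this sketch, random tensors stand in for the real `img` and `attn` (shapes taken from the question; squeeze the batch and channel dims off your own `img` first). The helper name `cls_attention_volume`, the min-max normalization, the head averaging option and the slice index are illustrative choices, not part of the answer above.

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Placeholders with the shapes from the question; use your real tensors instead.
img = torch.rand(120, 120, 120)          # grayscale MRI volume (D, H, W)
attn = torch.rand(6, 12, 65, 65)         # (layers, heads, tokens, tokens)

def cls_attention_volume(attn, layer, head=None, grid=4, size=(120, 120, 120)):
    """Hypothetical helper: upsample one layer's cls attention to the volume size.

    head=None averages over all heads instead of picking a single one.
    """
    a = attn[layer]                                  # (heads, 65, 65)
    a = a.mean(dim=0) if head is None else a[head]   # (65, 65)
    cls_attn = a[0, 1:].reshape(grid, grid, grid)    # cls row, cls -> cls dropped
    vol = F.interpolate(cls_attn[None, None].float(), size=size,
                        mode="trilinear", align_corners=False)[0, 0]
    # Min-max normalize to [0, 1] so it can be shown as a heat map.
    vol = (vol - vol.min()) / (vol.max() - vol.min() + 1e-8)
    return vol.detach().cpu()

heat = cls_attention_volume(attn, layer=4)           # (120, 120, 120), heads averaged

z = 60                                               # slice along the first axis
fig, ax = plt.subplots()
ax.imshow(img[z].numpy(), cmap="gray")               # anatomy underneath
ax.imshow(heat[z].numpy(), cmap="jet", alpha=0.4)    # semi-transparent attention
ax.set_axis_off()
plt.show()
```

Keep in mind that the interpolated map only has 4x4x4 real resolution; the trilinear upsampling smooths it but adds no spatial detail.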
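A single layer and head is only one view of the model. The vit-explain repository linked in the question combines all layers with attention rollout: for each layer, fuse the heads, add the identity to account for the residual connection, renormalize the rows, and multiply the per-layer matrices together. The sketch below applies that idea to the (6, 12, 65, 65) tensor using plain head averaging; it leaves out vit-explain's extra options (other head-fusion modes, discarding the lowest attentions), so treat it as an approximation of that repository's method, not a copy of it. The function name `attention_rollout` is hypothetical.

```python
import torch

def attention_rollout(attn):
    """attn: (layers, heads, tokens, tokens) -> (tokens, tokens) rollout matrix."""
    num_tokens = attn.shape[-1]
    eye = torch.eye(num_tokens, dtype=attn.dtype, device=attn.device)
    result = eye.clone()
    for layer_attn in attn:                      # iterate over layers, first to last
        a = layer_attn.mean(dim=0)               # fuse heads by averaging -> (65, 65)
        a = a + eye                              # account for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)      # renormalize rows to sum to 1
        result = a @ result                      # propagate through this layer
    return result

# Placeholder attention with rows that sum to 1, like real softmax outputs.
attn = torch.softmax(torch.randn(6, 12, 65, 65), dim=-1)
rollout = attention_rollout(attn)
cls_map = rollout[0, 1:].reshape(4, 4, 4)        # cls row over the 64 patch tokens
```

The resulting `cls_map` has the same 4x4x4 layout as a single-head map, so it can be plotted or upsampled to 120x120x120 in the same way.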