
I have a tensor patch_tensor of shape torch.Size([2, 77, 256]) and I want to unpatchify it back to (N, H, W, C) or (N, C, H, W). The original image shape is (2, 4, 64, 64).

For patch embedding, I am using PatchEmbed from the timm library:

from timm.models.vision_transformer import PatchEmbed

hidden_size = 36
in_channels = 4
patch_size = 8
input_size = 64
PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)

The Transformer encoder blocks are similar to those of the vanilla Transformer.

After extracting patches with PatchEmbed, the patch embeddings have shape torch.Size([2, 64, 36]): (64 / 8)^2 = 64 patches per image, each embedded into hidden_size = 36 dimensions.
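As a shape check, here is a minimal sketch (not timm's own code) of what PatchEmbed amounts to with its default flatten=True: a Conv2d whose kernel size and stride both equal patch_size, followed by flattening the 8 × 8 grid into 64 tokens. Plain torch.nn.Conv2d stands in for timm's class here as an assumption, and the input is random. Note that flatten(2) walks the grid row by row, so token index = row * 8 + col; that ordering matters when undoing the patching.

```python
import torch
import torch.nn as nn

# Assumption: this mirrors timm's PatchEmbed with flatten=True
# (Conv2d with kernel_size == stride == patch_size, then flatten + transpose).
N, in_channels, input_size = 2, 4, 64
patch_size, hidden_size = 8, 36

proj = nn.Conv2d(in_channels, hidden_size,
                 kernel_size=patch_size, stride=patch_size, bias=True)

x = torch.randn(N, in_channels, input_size, input_size)  # (2, 4, 64, 64)
grid = proj(x)                                           # (2, 36, 8, 8)
tokens = grid.flatten(2).transpose(1, 2)                 # (2, 64, 36), row-major patch order

num_patches = (input_size // patch_size) ** 2            # 8 * 8 = 64
print(grid.shape, tokens.shape, num_patches)
```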

This tensor then goes through the Transformer encoder, and the output has shape:

torch.Size([2, 77, 256])

From here, I am not sure how to start unpatchifying this tensor. Can you please help?
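One possible starting point, sketched under assumptions: the last dimension of the encoder output, 256, equals patch_size * patch_size * in_channels = 8 * 8 * 4, which is the layout an unpatchify step in DiT-style models expects (each token holds one p × p × C patch). The sequence length is 77 rather than the 64 patches PatchEmbed produced, so 13 tokens are not image patches. The sketch assumes the 64 patch tokens can be sliced out in PatchEmbed's row-major order; taking the first 64 is a hypothetical choice, and the correct slice depends on how the encoder builds its sequence. It also assumes each token's 256 values are ordered as (p, p, C). The function name unpatchify is illustrative.

```python
import torch


def unpatchify(tokens: torch.Tensor, patch_size: int, channels: int) -> torch.Tensor:
    """Rearrange (N, L, p*p*C) patch tokens into an (N, C, H, W) image.

    Assumes L forms a square grid of patches, tokens are in row-major grid
    order, and each token's values are laid out as (p, p, C).
    """
    n, length, dim = tokens.shape
    h = w = int(length ** 0.5)
    assert h * w == length, "token count must form a square grid"
    assert dim == patch_size * patch_size * channels, "last dim must be p*p*C"
    x = tokens.reshape(n, h, w, patch_size, patch_size, channels)
    # (n, h, w, p, q, c) -> (n, c, h, p, w, q): interleave grid rows/cols with pixel rows/cols
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(n, channels, h * patch_size, w * patch_size)


# Hypothetical: treat the first 64 of the 77 tokens as the image patches.
patch_tensor = torch.randn(2, 77, 256)
image_tokens = patch_tensor[:, :64, :]             # (2, 64, 256)
imgs = unpatchify(image_tokens, patch_size=8, channels=4)
print(imgs.shape)                                  # (N, C, H, W)
imgs_nhwc = imgs.permute(0, 2, 3, 1)               # (N, H, W, C) if preferred
print(imgs_nhwc.shape)
```

The einsum call is only a permutation; x.permute(0, 5, 1, 3, 2, 4) performs the same reordering.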

Thank you in advance.

Preetom Saha Arko
Jessica
