I have a patch_tensor with the shape torch.Size([2, 77, 256]), and I want to unpatchify it to (N, H, W, C) or (N, C, H, W). The original shape of the image is (2, 4, 64, 64).
For patch embedding, I am using PatchEmbed from the timm library:
hidden_size = 36
in_channels = 4
patch_size = 8
input_size = 64
from timm.models.vision_transformer import PatchEmbed
PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
The Transformer encoder block is similar to the vanilla Transformer encoder.
So after extracting patches with PatchEmbed, the shape of the patch embeddings is torch.Size([2, 64, 36]).
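For reference, timm's PatchEmbed is essentially a Conv2d whose kernel size and stride both equal patch_size, followed by flattening the spatial grid into a token axis. The sketch below reproduces that with plain PyTorch (an approximation of what timm does, not timm itself; the random input is a stand-in for the real images) to show where [2, 64, 36] comes from: (64 / 8) * (64 / 8) = 64 patches, each projected to hidden_size = 36.

```python
import torch
import torch.nn as nn

hidden_size, in_channels, patch_size, input_size = 36, 4, 8, 64

# Approximation of timm's PatchEmbed without a norm layer:
# a non-overlapping convolution with kernel = stride = patch_size.
proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size,
                 stride=patch_size, bias=True)

x = torch.randn(2, in_channels, input_size, input_size)  # dummy (N, C, H, W) batch
feat = proj(x)                            # (2, 36, 8, 8): one 36-d vector per 8x8 patch
tokens = feat.flatten(2).transpose(1, 2)  # (2, 64, 36): 8x8 grid flattened row-major

num_patches = (input_size // patch_size) ** 2
print(feat.shape, tokens.shape, num_patches)
```

The token order is row-major over the 8 x 8 patch grid, which matters later when the tokens are folded back into an image.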
This tensor then goes through the Transformer encoder, and the resulting shape is torch.Size([2, 77, 256]).
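A quick shape check that may help narrow this down: a vanilla encoder built from torch.nn.TransformerEncoderLayer keeps the sequence length and width of its input, so on its own it would return [2, 64, 36]. The 77 tokens and 256 features presumably come from something else in the model (for example, extra tokens concatenated to the sequence, or a linear projection). The sketch assumes nhead = 6 and two layers purely for illustration. It also shows that 256 happens to equal patch_size * patch_size * in_channels = 8 * 8 * 4, the per-token width an unpatchify step needs, while 77 is 13 more than the 64 patch positions.

```python
import torch
import torch.nn as nn

hidden_size, in_channels, patch_size, input_size = 36, 4, 8, 64

# Hypothetical vanilla encoder: nhead=6 and num_layers=2 are illustrative choices.
layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=6, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

tokens = torch.randn(2, 64, hidden_size)  # stand-in for the PatchEmbed output
out = encoder(tokens)
print(out.shape)  # sequence length and width are unchanged by the encoder

num_patches = (input_size // patch_size) ** 2      # patch positions in the image
per_patch = patch_size * patch_size * in_channels  # values needed per patch
print(num_patches, per_patch, 77 - num_patches)
```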
From this point, I am not sure how to unpatchify this tensor. Can you please help me? Thank you in advance.
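For reference, a minimal sketch of the usual reshape-and-permute approach to unpatchifying, under stated assumptions: (a) there is exactly one token per patch, in the row-major order PatchEmbed produces, and (b) each token's 256 features are laid out as (patch_size, patch_size, C). Since the encoder output has 77 tokens rather than 64, the usage lines also assume, hypothetically, that the image tokens are the first 64 and the remaining 13 can be discarded; if the extra tokens sit elsewhere in the sequence, the slice must change. The function name unpatchify is illustrative, not a timm API.

```python
import torch

def unpatchify(x: torch.Tensor, patch_size: int, channels: int) -> torch.Tensor:
    """Fold (N, L, p*p*C) patch tokens back into an (N, C, H, W) image.

    Assumes L is a perfect square (square patch grid, row-major order)
    and that each token is laid out as (p, p, C).
    """
    n, num_tokens, dim = x.shape
    p, c = patch_size, channels
    h = w = int(num_tokens ** 0.5)
    assert h * w == num_tokens, "token count must form a square grid"
    assert dim == p * p * c, "token width must equal patch_size**2 * channels"

    x = x.reshape(n, h, w, p, p, c)       # (N, grid_h, grid_w, p, p, C)
    x = x.permute(0, 5, 1, 3, 2, 4)       # (N, C, grid_h, p, grid_w, p)
    return x.reshape(n, c, h * p, w * p)  # (N, C, H, W)

patch_tensor = torch.randn(2, 77, 256)   # stand-in for the encoder output
image_tokens = patch_tensor[:, :64]      # hypothetical: keep only the first 64 tokens
img_nchw = unpatchify(image_tokens, patch_size=8, channels=4)  # (N, C, H, W)
img_nhwc = img_nchw.permute(0, 2, 3, 1)                        # (N, H, W, C)
print(img_nchw.shape, img_nhwc.shape)
```

The (p, p, C) layout is a convention: if the final layer that produces the 256 features per token is trained with this unpatchify, it learns whatever layout the reshape expects.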