I am trying to create an FID to measure the performance of my generative models on MNIST.
I provide my own feature extractor.
However, in order to determine the output dimension of a custom feature extractor, torchmetrics passes it a dummy image and inspects the shape of the output.
The problem is that the dummy image it generates does not match the shape or data type my feature extractor expects.
There is no way for me to manually specify the dummy image that should be passed in, so I can't control that.
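Judging from the printed shape and dtype in the output below, torchmetrics seems to do something roughly like this internally when a module is passed as feature (this is my paraphrase of the observed behaviour, not the actual library source):

import torch

# Sketch of the dimension check I believe torchmetrics performs:
# it builds an Inception-style uint8 dummy image and reads off the output size.
dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
num_features = feature_extractor(dummy_image).shape[-1]  # fails for my MNIST extractor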
Here is an example of what I'm trying to do:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.image.fid import FrechetInceptionDistance

N = <appropriate number>  # flattened size of the conv output for a single MNIST image

class SimpleConvFeatureExtractor(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=2)
        self.out = nn.Sequential(nn.Linear(N, embed_dim))

    def forward(self, x):
        print(x.shape)
        print(x.dtype)
        x = F.silu(self.conv(x))
        x = self.out(x.view(x.shape[0], -1))
        return x
fid = FrechetInceptionDistance(feature=SimpleConvFeatureExtractor(128))
with output
torch.Size([1, 3, 299, 299])
torch.uint8
RuntimeError: Input type (unsigned char) and bias type (float) should be the same
As you can see, the dummy image being passed through is hardly an MNIST image.
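For comparison, the kind of input my extractor is actually written for looks like this (a hypothetical MNIST-style batch, just to illustrate the mismatch):

import torch as th

# Hypothetical MNIST-style batch: float32, 1 channel, 28x28
mnist_batch = th.rand(16, 1, 28, 28)
features = SimpleConvFeatureExtractor(128)(mnist_batch)  # works once N is set correctly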