
I'm training a U-Net (model below) to predict 4 heatmaps, each a Gaussian centered on a keypoint, with one keypoint per channel. For some reason, every channel outputs the same result. Examples from test images are given below, where blue is the ground truth for that channel and red is the output of the U-Net. I have tried L1, MSE and Adaptive Wing loss (Wang 2019), and none of them is able to regress the heatmaps. I'm not sure what I'm doing wrong and would appreciate any advice. Thanks
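Since the question mentions Adaptive Wing loss (Wang 2019), here is a minimal PyTorch sketch of its piecewise form, for anyone wanting to compare against their own implementation. The function name `adaptive_wing_loss` is hypothetical, the default hyperparameters (alpha=2.1, omega=14, epsilon=1, theta=0.5) are assumed to match the paper and should be checked against it, and the paper's weighted loss map is deliberately left out.

```python
import torch


def adaptive_wing_loss(pred, target, alpha=2.1, omega=14.0, epsilon=1.0, theta=0.5):
    """Adaptive Wing loss, averaged over all elements (hypothetical helper).

    pred and target are heatmaps of the same shape, e.g. [B, 4, H, W],
    with target values in [0, 1]. Defaults are assumed, not verified.
    """
    diff = (target - pred).abs()
    # The exponent depends on the ground-truth value, so pixels near a peak
    # (target close to 1) are penalised differently from background pixels.
    power = alpha - target
    ratio = theta / epsilon
    # A and C make the linear branch meet the log branch smoothly at diff == theta.
    A = omega * (1.0 / (1.0 + ratio ** power)) * power * ratio ** (power - 1.0) / epsilon
    C = theta * A - omega * torch.log1p(ratio ** power)
    small = omega * torch.log1p((diff / epsilon) ** power)  # |error| < theta
    large = A * diff - C                                     # |error| >= theta
    return torch.where(diff < theta, small, large).mean()
```

The loss is zero when `pred` equals `target` and grows with the error; because the exponent is `alpha - target`, it behaves more like L1 on background pixels and stays sharper near the peaks.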

Test examples test1, test2, test3 and test4 (blue: ground truth for the channel; red: U-Net output; images not reproduced).
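The targets described in the question are one Gaussian per keypoint, stacked along the channel axis. A minimal sketch of building such targets in PyTorch follows; the helper name `gaussian_heatmaps`, the (x, y) keypoint layout and `sigma=2.0` are assumptions for illustration. One sanity check worth doing when all output channels look alike is confirming that each target channel really peaks at a different location.

```python
import torch


def gaussian_heatmaps(keypoints, height, width, sigma=2.0):
    """Build one Gaussian heatmap per keypoint (hypothetical helper).

    keypoints: tensor of shape [K, 2] holding (x, y) pixel coordinates.
    Returns a [K, height, width] tensor with value 1.0 at each keypoint.
    """
    ys = torch.arange(height, dtype=torch.float32).view(1, height, 1)
    xs = torch.arange(width, dtype=torch.float32).view(1, 1, width)
    kx = keypoints[:, 0].float().view(-1, 1, 1)
    ky = keypoints[:, 1].float().view(-1, 1, 1)
    # Squared distance from every pixel to each keypoint, broadcast to [K, H, W].
    dist2 = (xs - kx) ** 2 + (ys - ky) ** 2
    return torch.exp(-dist2 / (2.0 * sigma ** 2))


kp = torch.tensor([[10, 12], [40, 8], [25, 50], [55, 55]])
heat = gaussian_heatmaps(kp, 64, 64)
# Each channel should peak at its own keypoint, i.e. the argmaxes should differ.
print(heat.view(4, -1).argmax(dim=1))
```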

```python
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64))

        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64))

        self.layer3 = nn.Sequential(
            nn.MaxPool2d(2, stride=2, padding=0))

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128))

        self.layer5 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128))

        self.layer6 = nn.Sequential(
            nn.MaxPool2d(2, stride=2, padding=0))

        self.layer7 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(256))
        # forward() and any further layers are not included in the snippet.
```
Edited by Christoph Rackwitz; asked by solace.

1 Answer


It's hard to see how this is a U-Net. The only components that modify the spatial shape of your input are your MaxPool2d layers. You have two of them, so for an input of size [B, 1, H, W], your output will have the shape [B, 256, H/4, W/4].
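The shape argument can be checked directly. The sketch below rebuilds the question's layers and, as an assumption (the question's forward() is not shown), applies layer1 through layer7 in order; the helper name `conv_block` is hypothetical.

```python
import torch
import torch.nn as nn


def conv_block(in_ch, out_ch):
    # Same Conv -> ReLU -> BatchNorm pattern as the question's layers;
    # a 3x3 kernel with padding=1 and stride=1 keeps H and W unchanged.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_ch),
    )


# Assumption: the unseen forward() chains layer1..layer7 in order.
encoder = nn.Sequential(
    conv_block(1, 64),
    conv_block(64, 64),
    nn.MaxPool2d(2, stride=2),  # H, W -> H/2, W/2
    conv_block(64, 128),
    conv_block(128, 128),
    nn.MaxPool2d(2, stride=2),  # H/2, W/2 -> H/4, W/4
    conv_block(128, 256),
)

x = torch.randn(2, 1, 64, 64)  # [B, 1, H, W]
with torch.no_grad():
    out = encoder(x)
print(out.shape)  # torch.Size([2, 256, 16, 16])
```

Nothing in this stack upsamples back to H x W or reduces 256 channels to 4 heatmaps, which is what the decoder half of a U-Net would do.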

I think you need to post a more complete code snippet (I don't have enough reputation to leave this as a comment).

  • It looks somewhat like *the first half* of a U-Net. A three-month-old question is probably no longer relevant to the person who asked it, though it can be worth engaging if the question appears to have some general value. You too seem to think that this question has no general value, because OP constructed their own network that is apparently non-standard and not working properly. – Christoph Rackwitz Aug 01 '22 at 10:53
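For comparison with the "first half" remark, here is a minimal sketch of a generic two-level U-Net that adds the missing decoder: transposed-convolution upsampling, skip connections via concatenation, and a 1x1 convolution producing one heatmap per keypoint at full resolution. This is not the asker's model; the class name `SmallUNet`, the depth, the channel widths and the Conv -> BatchNorm -> ReLU ordering are assumptions, and H and W are assumed divisible by 4.

```python
import torch
import torch.nn as nn


def double_conv(in_ch, out_ch):
    # Two 3x3 convolutions that preserve spatial size.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class SmallUNet(nn.Module):
    """Hypothetical two-level U-Net: 1 input channel -> num_keypoints heatmaps."""

    def __init__(self, num_keypoints=4):
        super().__init__()
        self.enc1 = double_conv(1, 64)
        self.enc2 = double_conv(64, 128)
        self.bottleneck = double_conv(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = double_conv(256, 128)  # 128 upsampled + 128 skip channels
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = double_conv(128, 64)   # 64 upsampled + 64 skip channels
        # 1x1 conv: one output channel per keypoint, no final activation.
        self.head = nn.Conv2d(64, num_keypoints, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                                     # [B, 64, H, W]
        e2 = self.enc2(self.pool(e1))                         # [B, 128, H/2, W/2]
        b = self.bottleneck(self.pool(e2))                    # [B, 256, H/4, W/4]
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))   # [B, 128, H/2, W/2]
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # [B, 64, H, W]
        return self.head(d1)                                  # [B, K, H, W]


net = SmallUNet(num_keypoints=4)
with torch.no_grad():
    y = net(torch.randn(2, 1, 64, 64))
print(y.shape)  # torch.Size([2, 4, 64, 64])
```

The skip connections give the decoder full-resolution features from the encoder, which is what lets each output channel localise its own keypoint instead of producing a coarse, shared blob.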