
Problem statement: I have an image in which each pixel belongs to exactly one of 'Band5', 'Band6' or 'Band7' (see below for details). Hence, I have a multi-class (per-pixel) classification problem in PyTorch, but I do not understand how to set up the targets, which need to have the form [batch, w, h].

My dataloader returns two values:

x = chips.loc[:, :, :, self.input_bands]     
y = chips.loc[:, :, :, self.output_bands]        
x = x.transpose('chip','channel','x','y')
y_ohe = y.transpose('chip','channel','x','y')
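If each pixel has exactly one of Band5, Band6 and Band7 set (i.e. the three output bands form a one-hot encoding, which is an assumption here), the target could be collapsed to class indices once in the dataloader rather than in the training loop. A minimal NumPy sketch, with a small random array standing in for `y_ohe.values`; the helper name `bands_to_class_index` is hypothetical:

```python
import numpy as np

def bands_to_class_index(y_ohe: np.ndarray) -> np.ndarray:
    """Collapse a one-hot target of shape (chip, channel, x, y) into a
    class-index map of shape (chip, x, y).

    Index 0 -> Band5, 1 -> Band6, 2 -> Band7 (the order of output_bands).
    """
    # argmax over the channel axis picks the band that is "on" at each pixel;
    # int64 matches the dtype nn.CrossEntropyLoss expects for index targets
    return np.argmax(y_ohe, axis=1).astype(np.int64)

# Stand-in for y_ohe.values: 2 chips, 3 bands, 4 x 4 pixels, one-hot per pixel
rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=(2, 4, 4))
# np.eye(3)[labels] has shape (chip, x, y, channel); move channel to axis 1
y_ohe = np.eye(3)[labels].transpose(0, 3, 1, 2)

y_idx = bands_to_class_index(y_ohe)
print(y_idx.shape)  # (2, 4, 4)
```

If some pixels have none of the three bands set, argmax silently assigns them class 0 (Band5); such pixels would need a dedicated label value that is then passed as `ignore_index` to `nn.CrossEntropyLoss`.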

Also, I have defined:

input_bands = ['Band1', 'Band2', 'Band3', 'Band3', 'Band4']  # input bands (channels); 'Band3' is listed twice, giving 5 channels
output_bands = ['Band5','Band6', 'Band7'] #target classes

model = ModelName(num_classes = 3, depth=default_depth, in_channels=5, merge_mode='concat').to(device)
loss_new = nn.CrossEntropyLoss()
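For reference, `nn.CrossEntropyLoss` for per-pixel classification takes raw logits of shape [batch, num_classes, H, W] and an integer (int64) target of shape [batch, H, W] whose entries are class indices in 0..num_classes-1. A sketch with random tensors standing in for the output of `ModelName` (the sizes are arbitrary assumptions):

```python
import torch
import torch.nn as nn

batch, num_classes, H, W = 2, 3, 8, 8

# Stand-in for model(images): unnormalized scores, one channel per class
logits = torch.randn(batch, num_classes, H, W)

# Target: one class index (0, 1 or 2) per pixel, with no channel dimension
target = torch.randint(0, num_classes, (batch, H, W), dtype=torch.long)

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)  # averaged over all pixels -> scalar
print(logits.shape, target.shape, loss.shape)
# torch.Size([2, 3, 8, 8]) torch.Size([2, 8, 8]) torch.Size([])
```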

In my training function:

        #get values from dataloader
        X = normalize_zero_to_one(X) #input
        y = normalize_zero_to_one(y) #target

        images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]
        masks = Variable(torch.from_numpy(y)).to(device) 
        optim.zero_grad()        
        outputs = model(images) 

        loss = loss_new(outputs, masks) # (preds, target)
        loss.backward()         
        optim.step() # Update weights  

I know that the target (here masks) should be [batch_size, w, h]; however, it is currently [batch_size, channels, w, h].

I have read many posts, including 1, 2, which say the target should contain only the class indices. I don't understand how I can combine the indices of three classes and still have a target of shape [batch_size, w, h].
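No concatenation is needed: each pixel stores a single integer naming its band (here 0 = Band5, 1 = Band6, 2 = Band7, following the order of output_bands). A [batch, 3, w, h] one-hot target and a [batch, w, h] index target carry the same information, as this small round trip with `torch.nn.functional.one_hot` and `torch.argmax` illustrates (the 2 x 3 "image" is made up):

```python
import torch
import torch.nn.functional as F

# Index target for one 2 x 3 image: each entry is the band of that pixel
idx = torch.tensor([[[0, 1, 2],
                     [2, 2, 0]]])  # shape [1, 2, 3] = [batch, w, h]

# one_hot appends the class axis last; permute moves it to dim 1
one_hot = F.one_hot(idx, num_classes=3).permute(0, 3, 1, 2)
print(one_hot.shape)  # torch.Size([1, 3, 2, 3])

# argmax over the class axis recovers the original indices
recovered = torch.argmax(one_hot, dim=1)
print(torch.equal(recovered, idx))  # True
```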

Right now, I get the error:

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

To the best of my understanding, I don't need to do any one-hot encoding. Similar errors and explanations I found online are here:

Any help will be appreciated! Thank you.

Sulphur

1 Answer


If I understand correctly, your current "target" is [batch_size, channels, w, h] with channels == 3 because you have three possible classes.
What do the values in your target represent? You basically have a 3-vector target for each pixel: are these the expected class probabilities? Are they one-hot vectors indicating the correct "band"? If so, you can get the target indices by simply taking the argmax along the target's channel dimension:

proper_target = torch.argmax(masks, dim=1)  # keepdim=False (the default) drops the channel dim: [batch_size, w, h], int64
loss = loss_new(outputs, proper_target)
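Combining this with the training step from the question gives something like the following sketch. A single `nn.Conv2d` stands in for `ModelName`, and random NumPy arrays stand in for the dataloader output (both are assumptions). The target is left unnormalized, since only its argmax is used, and `Variable` is not needed in current PyTorch, where tensors track gradients directly.

```python
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for ModelName(num_classes=3, in_channels=5, ...)
model = nn.Conv2d(in_channels=5, out_channels=3, kernel_size=1).to(device)
loss_new = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Stand-in for one batch from the dataloader (NumPy arrays, as in the question)
rng = np.random.default_rng(0)
X = rng.random((4, 5, 16, 16), dtype=np.float32)       # [batch, channel, H, W]
labels = rng.integers(0, 3, size=(4, 16, 16))
y = np.eye(3, dtype=np.float32)[labels].transpose(0, 3, 1, 2)
y = np.ascontiguousarray(y)                             # one-hot [batch, 3, H, W]

images = torch.from_numpy(X).to(device)
masks = torch.from_numpy(y).to(device)

optim.zero_grad()
outputs = model(images)                      # [batch, 3, H, W] logits
proper_target = torch.argmax(masks, dim=1)   # [batch, H, W], dtype int64
loss = loss_new(outputs, proper_target)
loss.backward()
optim.step()
print(loss.item())
```

If the three target values are instead per-pixel class probabilities rather than one-hot labels, PyTorch 1.10 and later also accept a floating-point target with the same shape as outputs, so loss_new(outputs, masks) would work without the argmax; older versions require the index form above.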
Shai