```python
import torch


def compute_loss(ouputs, y, criterion):
    """Compute the loss given the output, ground truth y, and the criterion function.

    Hint: criterion should be cross entropy.

    Hint: the input is a list of tensors, each with shape (batch_size, vocab_size).

    Hint: you need to concatenate the tensors and reshape the result to
    (batch_size*num_step, vocab_size). Then compute the loss with y.

    Returns:
        output: A 0-d tensor: the averaged loss (scalar)
    """
    ### YOUR CODE HERE
    #print("output in compute_loss length: ", len(ouputs))
    print("initial single output shape in compute_loss length: ", ouputs[0].shape)
    for element in range(len(ouputs)):
        if element == 0:
            output = ouputs[element]
        else:
            output = torch.cat((output, ouputs[element]), 0)
    output = torch.reshape(output, (y.size(dim=0) * len(ouputs), ouputs[0].size(dim=1)))
    print("y in compute_loss: ", y.shape)
    print("Final output shape:", output.shape)
    print("\n", output[0].shape)
    loss = criterion(output, y)  # this call raises the ValueError shown below
    print("criterion in compute_loss: ", loss)

    ### END YOUR CODE
    return loss
```
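The concatenation hint in the docstring is easy to get subtly wrong, because `torch.cat(..., 0)` stacks the per-step outputs step by step, while `y` is laid out one row per sequence. The sketch below uses tiny illustrative sizes (batch_size=2, num_steps=3) and hypothetical stand-in tensors in which each entry simply encodes its (example, step) position, to show which flattening of `y` lines up with the concatenated rows. It assumes, as the printed shapes suggest, that the list holds one tensor per time step and that `y` has shape (batch_size, num_steps).

```python
import torch

batch_size, num_steps = 2, 3  # illustrative sizes, much smaller than 64 and 10

# Hypothetical stand-in for the model outputs: one tensor per time step with
# shape (batch_size, 1). Each entry encodes its position as 10 * example + step.
outputs = [
    torch.tensor([[10 * b + t] for b in range(batch_size)])
    for t in range(num_steps)
]

# y has shape (batch_size, num_steps); y[b, t] uses the same encoding.
y = torch.tensor([[10 * b + t for t in range(num_steps)] for b in range(batch_size)])

stacked = torch.cat(outputs, dim=0).reshape(-1)  # rows ordered step by step
row_major = y.reshape(-1)                         # ordered example by example
step_major = y.t().reshape(-1)                    # transpose first: step by step

print(stacked.tolist())     # [0, 10, 1, 11, 2, 12]
print(row_major.tolist())   # [0, 1, 2, 10, 11, 12]
print(step_major.tolist())  # [0, 10, 1, 11, 2, 12]
```

Only the transposed flattening matches the concatenated order, so reshaping `y` to a flat vector without transposing it first would silence a shape error while pairing predictions with the wrong targets.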

I get the output below, followed by an error:

```
initial single output shape in compute_loss length:  torch.Size([64, 10000])
y in compute_loss:  torch.Size([64, 10])
Final output shape: torch.Size([640, 10000])
```

The error is:

```
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3012     if size_average is not None or reduce is not None:
   3013         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3014     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3015 
   3016 

ValueError: Expected input batch_size (640) to match target batch_size (64).
```

The final output shape is torch.Size([640, 10000]) and I pass it to criterion(output, y), but the error still reports an input batch size of 640 that does not match a target batch size of 64. Since the output already has 640 rows, what is causing the mismatch?
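The message is about the target rather than the input: for a 2-D input of shape (N, C), `nn.CrossEntropyLoss` expects class indices of shape (N,), so with N = 640 it reads the first dimension of `y` (64) as the target batch size. A minimal sketch of one fix follows, assuming (as in the printed shapes) that `ouputs` holds num_steps tensors of shape (batch_size, vocab_size) in time order, `y` is an integer tensor of shape (batch_size, num_steps), and `criterion` is `nn.CrossEntropyLoss` with its default mean reduction. The sizes and random tensors in the usage lines are illustrative only.

```python
import torch
import torch.nn as nn


def compute_loss(ouputs, y, criterion):
    """Average cross-entropy over all time steps.

    ouputs: list of num_steps tensors, each (batch_size, vocab_size).
    y: integer tensor of shape (batch_size, num_steps).
    """
    # (num_steps * batch_size, vocab_size), rows ordered step by step.
    output = torch.cat(ouputs, dim=0)
    # Transpose y to (num_steps, batch_size) before flattening so the targets
    # follow the same step-by-step order. reshape (not view) is used because
    # y.t() is not contiguous.
    target = y.t().reshape(-1).long()
    loss = criterion(output, target)  # 0-d tensor holding the mean loss
    return loss


# Usage with small illustrative sizes.
torch.manual_seed(0)
batch_size, num_steps, vocab_size = 4, 3, 7
outs = [torch.randn(batch_size, vocab_size) for _ in range(num_steps)]
y = torch.randint(vocab_size, (batch_size, num_steps))
loss = compute_loss(outs, y, nn.CrossEntropyLoss())
print(loss.shape)  # torch.Size([])
```

Because every time step contributes the same number of rows, the mean over all 640 rows equals the mean of the per-step losses, which is a convenient sanity check.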

Ethereal soul
