I want to use PyTorch's CrossEntropyLoss, but somehow my code only works with batch size 2, so I am assuming there is something wrong with the shapes of the target and the output. I get the following error:

ValueError: Expected target size (50, 2), got torch.Size([50, 3])

My target has shape (N=50, batch_size=3) and the output of my model has shape (N=50, batch_size=3, number of classes=2). Before the output layer, the shape is (N=50, batch_size=3, dimensions=64).

How do I need to change the shapes so that CrossEntropyLoss works?

desertnaut
Thomas231
  • I have posted an answer, but it seems it will not suit your needs. Could you share with us a little more about the model? Is it an RNN? Could you share some code with us? – Ivan Jan 10 '21 at 12:08
  • Of course. It is a transformer model for knowledge tracing called SAINT. So N is the sequence length of a user answering questions, and for each question you want to predict whether the user will answer right or wrong. – Thomas231 Jan 10 '21 at 12:08
  • You will need to provide some code, your `nn.Transformer` setup (if you do use this module). We can't guess what the issue is from here. – Ivan Jan 10 '21 at 12:11
  • I did not use the nn.Transformer setup. Which part do you need? The entire model? The train_epoch function? – Thomas231 Jan 10 '21 at 12:15

1 Answer

Without further information about your model, here's what I would do. I am assuming you have a many-to-many RNN which outputs (seq_len, batch_size, nb_classes), and the target is (seq_len, batch_size). The nn.CrossEntropyLoss module accepts inputs with additional dimensions, of shape (batch_size, nb_classes, d1, d2, ..., dK), together with a target of shape (batch_size, d1, d2, ..., dK). The class dimension must come second.
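To see where the error message comes from, here is a minimal sketch using random stand-in tensors with the shapes from the question (the variable names and the random data are my own assumptions, not your model). Because dimension 1 is read as the class dimension, an output of shape (50, 3, 2) is interpreted as 50 samples, 3 classes and one extra dimension of size 2, so a target of shape (50, 2) is expected. With batch size 2 the shapes happen to line up, which is probably why that case runs without complaint even though the axes are still mixed up.

```python
import torch
import torch.nn as nn

seq_len, batch_size, nb_classes = 50, 3, 2
criterion = nn.CrossEntropyLoss()

# Random stand-ins with the shapes described in the question.
output = torch.randn(seq_len, batch_size, nb_classes)          # (50, 3, 2)
target = torch.randint(0, nb_classes, (seq_len, batch_size))   # (50, 3)

# Dim 1 is treated as the class dimension, so the loss expects a
# target of shape (50, 2) here and rejects the (50, 3) target.
# Depending on the PyTorch version this is a ValueError or a RuntimeError.
try:
    criterion(output, target)
    error_message = None
except (ValueError, RuntimeError) as err:
    error_message = str(err)
print(error_message)

# With batch_size == 2 the shapes coincide, so no error is raised,
# although the batch axis is still being used as the class axis.
accidental_loss = criterion(
    torch.randn(seq_len, 2, nb_classes),
    torch.randint(0, nb_classes, (seq_len, 2)),
)
print(accidental_loss.shape)
```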

You could make it work by permuting the axes so that the output tensor has shape (batch_size, nb_classes, seq_len):

output = output.permute(1, 2, 0)

The target also has to change, to (batch_size, seq_len):

target = target.permute(1, 0)
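
As a sanity check, here is a small self-contained sketch with random stand-in tensors (the names and data are assumptions for illustration, not taken from your model). It moves the class axis to position 1 with permute(1, 2, 0) and also shows an alternative that I find easier to read: flatten the sequence and batch axes so the loss sees a plain (N, nb_classes) input. With the default mean reduction both should give the same value.

```python
import torch
import torch.nn as nn

seq_len, batch_size, nb_classes = 50, 3, 2
criterion = nn.CrossEntropyLoss()

# Random stand-ins with the shapes described in the question.
output = torch.randn(seq_len, batch_size, nb_classes)          # (50, 3, 2)
target = torch.randint(0, nb_classes, (seq_len, batch_size))   # (50, 3)

# Option 1: put the class axis second.
# output: (seq_len, batch, classes) -> (batch, classes, seq_len)
# target: (seq_len, batch)          -> (batch, seq_len)
loss_permuted = criterion(output.permute(1, 2, 0), target.permute(1, 0))

# Option 2: merge the sequence and batch axes into one.
# output: (seq_len * batch, classes), target: (seq_len * batch,)
loss_flat = criterion(output.reshape(-1, nb_classes), target.reshape(-1))

print(loss_permuted.item(), loss_flat.item())
```

Both options average over all seq_len * batch_size predictions, so which one to use is mostly a matter of taste.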
Ivan
  • No, N is the number of values I want to predict. In this case it is the sequence length of a user answering questions, and for each question you want to predict whether the user will answer right or wrong. The model is a transformer model, by the way. – Thomas231 Jan 10 '21 at 12:04