2

Assuming that I have a MLP that uses ReLU as activation function and CrossEntropyLoss as loss function to classify samples with 3 features that are part of one of 10 classes: How would I implement that? The target values are given as numbers from 0 to 9. When using CrossEntropyLoss the target values have to be simple numbers instead one hot vectors. But when trying to convert the results of the MLP into a single number I get an index error.

The standard implementation of the MLP:

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__() 
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.output_size = output_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
        self.softmax = torch.nn.Softmax()
    
    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        output = self.softmax(output)
        return output

As well as the execution that gives me an error:

mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
for epoch in range(epoch):
    y_pred = mlp_model(x_train)
    y_scalar = torch.argmax(y_pred, dim=1)

    loss = criterion(y_scalar, y_train) <-------------- error

    loss.backward()
mlp_model.eval()
y_pred = mlp_model(x_test)
y_scalar = torch.argmax(y_pred, dim=1)
test_loss = criterion(y_scalar, y_test) 
print('Test loss after Training' , test_loss.item())

y_pred_list = y_pred.tolist()
y_test_list = y_test.tolist()

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test_list, y_pred_list)

The error: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Output of y_scalar and y_train:

tensor([1, 3, 3, 3, 1, 1, 1, 3, 3, 1, 3, 1, 1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3,
        1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3,
        3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 1,
        1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3,
        3, 1, 3, 1, 3, 3, 3, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 3, 3, 1, 3, 3,
        1, 3, 1, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 3, 1, 3, 1, 1, 3, 3, 1, 1, 1,
        1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 1,
        1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3,
        3, 3, 3, 3, 3, 1, 3, 1, 1, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        3, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3,
        1, 3, 1, 3, 1, 3, 3, 3, 1, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1,
        1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3,
        3, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 3,
        1, 1, 1, 3, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3, 3, 1,
        3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 1, 1, 3, 1,
        3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 3, 1,
        1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 1, 1, 3, 3, 1,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 1, 3, 1,
        3, 1, 3, 1, 1, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 1, 1, 3,
        1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 1, 3, 3,
        1, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1,
        3, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 1, 3, 3, 3, 3,
        3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 1,
        3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 3, 3, 1, 1,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3])
tensor([3., 4., 4., 0., 3., 2., 0., 3., 3., 2., 0., 0., 4., 3., 3., 3., 2., 3.,
        1., 3., 5., 3., 4., 6., 3., 3., 6., 3., 2., 4., 3., 6., 0., 4., 2., 0.,
        1., 5., 4., 4., 3., 6., 6., 4., 3., 3., 2., 5., 3., 4., 5., 3., 0., 2.,
        1., 4., 6., 3., 2., 2., 0., 0., 0., 4., 2., 0., 4., 5., 2., 6., 5., 2.,
        2., 2., 0., 4., 5., 6., 4., 0., 0., 0., 4., 2., 4., 1., 4., 6., 0., 4.,
        2., 4., 6., 6., 0., 0., 6., 5., 0., 6., 0., 2., 1., 1., 1., 2., 6., 5.,
        6., 1., 2., 2., 1., 5., 5., 5., 6., 5., 6., 5., 5., 1., 6., 6., 1., 5.,
        1., 6., 5., 5., 5., 1., 5., 1., 1., 1., 1., 1., 1., 1., 4., 3., 0., 3.,
        6., 6., 0., 3., 4., 0., 3., 4., 4., 1., 2., 2., 2., 3., 3., 3., 3., 0.,
        4., 5., 0., 3., 4., 3., 3., 3., 2., 3., 3., 2., 2., 6., 1., 4., 3., 3.,
        3., 6., 3., 3., 3., 3., 0., 4., 2., 2., 6., 5., 3., 5., 4., 0., 4., 3.,
        4., 4., 3., 3., 2., 4., 0., 3., 2., 3., 3., 4., 4., 0., 3., 6., 0., 3.,
        3., 4., 3., 3., 5., 2., 3., 2., 4., 1., 3., 2., 2., 3., 3., 3., 3., 5.,
        1., 3., 1., 3., 5., 0., 3., 5., 0., 4., 2., 4., 2., 4., 4., 5., 4., 3.,
        5., 3., 3., 4., 3., 0., 4., 5., 0., 3., 6., 2., 5., 5., 5., 3., 2., 3.,
        0., 4., 5., 3., 0., 4., 0., 3., 3., 0., 0., 3., 5., 4., 4., 3., 4., 3.,
        3., 2., 2., 3., 0., 3., 1., 3., 2., 3., 3., 4., 5., 2., 1., 1., 0., 0.,
        1., 6., 1., 3., 3., 3., 2., 3., 3., 0., 3., 4., 1., 3., 4., 3., 2., 0.,
        0., 4., 2., 3., 2., 1., 4., 6., 3., 2., 0., 3., 3., 2., 3., 4., 4., 2.,
        1., 3., 5., 3., 2., 0., 4., 5., 1., 3., 3., 2., 0., 2., 4., 2., 2., 2.,
        5., 4., 4., 2., 2., 0., 3., 2., 4., 4., 5., 5., 1., 0., 3., 4., 5., 3.,
        4., 5., 3., 4., 3., 3., 1., 4., 3., 3., 5., 2., 3., 2., 5., 5., 4., 3.,
        3., 3., 3., 1., 5., 3., 3., 2., 6., 0., 1., 3., 0., 1., 5., 3., 6., 3.,
        6., 0., 3., 3., 3., 5., 4., 3., 4., 0., 5., 2., 1., 2., 4., 4., 4., 4.,
        3., 3., 0., 4., 3., 0., 5., 2., 0., 5., 4., 4., 4., 3., 0., 6., 5., 2.,
        4., 5., 1., 3., 5., 3., 0., 3., 5., 1., 1., 0., 3., 4., 2., 6., 2., 0.,
        5., 3., 4., 6., 5., 3., 5., 0., 1., 3., 0., 5., 2., 2., 3., 5., 1., 0.,
        3., 1., 4., 2., 5., 6., 4., 2., 2., 6., 0., 0., 4., 6., 3., 2., 0., 3.,
        6., 1., 6., 3., 1., 3., 3., 3., 3., 2., 5., 4., 5., 5., 3., 1., 3., 3.,
        4., 4., 2., 0., 2., 0., 5., 4., 0., 0., 3., 2., 2., 2., 2., 6., 4., 6.,
        5., 5., 1., 0., 0., 4., 3., 3., 1., 3., 6., 6., 2., 3., 3., 3., 1., 2.,
        2., 5., 4., 3., 2., 1., 2., 2., 3., 2., 3., 2., 3., 3., 0., 5., 3., 3.,
        3., 4., 5., 3., 2., 1., 4., 4., 4., 4., 0., 5., 4., 1., 3., 0., 3., 4.,
        6., 3., 6., 3., 3., 3., 6., 3., 4., 3., 6., 3., 0., 3., 1., 2., 5., 6.,
        5., 2., 0., 2., 2., 3., 3., 0., 3., 5., 3., 4., 0., 3., 2., 4., 5., 2.,
        3., 2., 2., 3., 5., 2., 0., 3., 4., 3.])```
MenNotAtWork
  • 145
  • 2
  • 8

1 Answers1

0

As mentioned in the comment softmax is not required inside the model as nn.CrossEntropyLoss includes it. Also, calculation of the loss is done before argmax. Note also the shapes of input and outputs to the model. Please refer the following updates.

import torch
class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__() 
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.output_size = output_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
        #self.softmax = torch.nn.Softmax()
    
    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        #output = self.softmax(output)
        return output

mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
x_train = torch.randn(100, 3) # random 100 inputs of shape (100, 3)
y_train = torch.randint(low=0, high=10, size=(100,)) # random 100 ground truths of shape (100,)
for epoch in range(epoch):
    y_pred = mlp_model(x_train)
    y_scalar = torch.argmax(y_pred, dim=1)

    #loss = criterion(y_scalar, y_train)# <-------------- error
    loss = criterion(y_pred, y_train) # loss calculated before argmax

    loss.backward().....