
I'm training a model and trying to add a confusion matrix that should be displayed in my wandb dashboard, but I got a bit lost. The matrix itself works (I can print it), but it never shows up in wandb. As far as I can tell everything should be OK, yet it isn't. Can you please help me? I'm new to all this. Thanks a lot!

The code:

    import copy
    import math
    import random
    import time

    import torch
    import wandb
    from sklearn.metrics import f1_score  # import once, at module level


    # Assumed signature: the loop looks like the PyTorch transfer-learning tutorial,
    # where dataloaders, dataset_sizes and device are module-level globals.
    def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
        since = time.time()

        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0
        nb_classes = 7
        batch_f1 = None

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch + 1, num_epochs))
            print('-' * 10)

            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0
                # One matrix per phase: rows = true class, columns = predicted class
                confusion_matrix = torch.zeros(nb_classes, nb_classes)

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()

                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels)

                    # Per-class F1 for this batch; a separate variable name keeps the
                    # imported f1_score function usable on the next batch
                    batch_f1 = f1_score(labels.cpu().numpy(), preds.cpu().numpy(),
                                        average=None)
                    wandb.log({'F1 score': batch_f1})

                    # Count every val batch's (true, predicted) pairs, reusing the
                    # predictions already computed above instead of a second pass
                    if phase == 'val':
                        for t, p in zip(labels.view(-1), preds.view(-1)):
                            confusion_matrix[t.long(), p.long()] += 1

                if phase == 'train':
                    scheduler.step()
                else:
                    # Logged once per epoch, after all val batches. This is the matrix
                    # that prints fine but does not show up as a matrix in wandb.
                    wandb.log({'matrix': confusion_matrix})

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                wandb.log({'epoch loss': epoch_loss,
                           'epoch acc': epoch_acc})

                # Demo line plot from random data (unrelated to the training metrics)
                data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
                table = wandb.Table(data=data, columns=["step", "height"])
                wandb.log({'line-plot1': wandb.plot.line(table, "step", "height")})

                print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc))

                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:4f}'.format(best_acc))
        print('Last batch F1 score: {}'.format(batch_f1))

        model.load_state_dict(best_model_wts)
        return model
toyota Supra
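Stripped of PyTorch and wandb, the counting this loop is after can be sketched in plain Python: start one matrix per validation pass, add every batch's (true, predicted) pairs inside the batch loop, and log only after the last batch. This is a minimal illustration; `new_confusion`, `update_confusion` and the toy batches are hypothetical names and data, not part of the code above. Passing torch tensors instead of lists should also work, since single-element tensors support `int()`.

```python
from typing import Iterable, List


def new_confusion(nb_classes: int) -> List[List[int]]:
    """Return an nb_classes x nb_classes zero matrix (rows: true, cols: predicted)."""
    return [[0] * nb_classes for _ in range(nb_classes)]


def update_confusion(cm: List[List[int]], labels: Iterable, preds: Iterable) -> None:
    """Add one batch of (true, predicted) pairs to cm in place."""
    for t, p in zip(labels, preds):
        cm[int(t)][int(p)] += 1


# Hypothetical validation batches for 3 classes: (labels, preds) per batch
val_batches = [
    ([0, 1, 2], [0, 2, 2]),
    ([1, 1, 0], [1, 1, 0]),
]

cm = new_confusion(3)          # created once per validation pass
for labels, preds in val_batches:
    update_confusion(cm, labels, preds)   # every batch contributes
# Only now, after the last batch, would the matrix be logged
print(cm)  # [[2, 0, 0], [0, 2, 1], [0, 0, 1]]
```

Keeping creation, accumulation and logging at three different loop levels is the whole point: creating the matrix inside the batch loop resets it, and logging inside the batch loop sends a partial matrix.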

1 Answer


Have you tried the built-in confusion matrix plot that comes with wandb (`wandb.plot.confusion_matrix`)?

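    # ground_truth, predictions: true and predicted class index for every sample;
    # class_names: one label string per class
    cm = wandb.plot.confusion_matrix(
        y_true=ground_truth,
        preds=predictions,
        class_names=class_names)

    wandb.log({"conf_mat": cm})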
morganmcg
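The call above expects `ground_truth` and `predictions` to cover the whole evaluation split as flat sequences of class indices. A sketch under that assumption: gather `(labels, preds)` from every val batch, flatten them to Python ints (anything with a `.tolist()` method, such as a torch tensor or NumPy array, is converted first), and make one `wandb.plot.confusion_matrix` call per epoch. `to_int_list`, `collect_labels` and `log_confusion_matrix` are hypothetical helper names; `class_names` would be a list of 7 strings for the question's `nb_classes = 7`.

```python
from typing import Any, Iterable, List, Tuple


def to_int_list(x: Any) -> List[int]:
    """Flatten a tensor, NumPy array, or (nested) list of class indices into ints."""
    if hasattr(x, "tolist"):  # torch.Tensor and numpy.ndarray both provide .tolist()
        x = x.tolist()
    if isinstance(x, (list, tuple)):
        flat: List[int] = []
        for item in x:
            flat.extend(to_int_list(item))
        return flat
    return [int(x)]  # a single scalar class index


def collect_labels(batches: Iterable[Tuple[Any, Any]]) -> Tuple[List[int], List[int]]:
    """Concatenate per-batch (labels, preds) into two flat lists for the whole split."""
    ground_truth: List[int] = []
    predictions: List[int] = []
    for labels, preds in batches:
        ground_truth.extend(to_int_list(labels))
        predictions.extend(to_int_list(preds))
    return ground_truth, predictions


def log_confusion_matrix(ground_truth: List[int], predictions: List[int],
                         class_names: List[str]) -> None:
    """Log a wandb confusion-matrix chart; needs wandb installed and wandb.init() called."""
    import wandb  # local import so the helpers above work without wandb installed
    cm = wandb.plot.confusion_matrix(
        y_true=ground_truth, preds=predictions, class_names=class_names)
    wandb.log({"conf_mat": cm})


# Toy per-batch data standing in for (labels, preds) from dataloaders['val']
batches = [([0, 1, 2], [0, 2, 2]), ((1, 1, 0), [[1], [1], [0]])]
ground_truth, predictions = collect_labels(batches)
print(ground_truth)  # [0, 1, 2, 1, 1, 0]
print(predictions)   # [0, 2, 2, 1, 1, 0]
```

In the training loop, the `(labels, preds)` pairs would be appended during the `'val'` phase and `log_confusion_matrix` called once after that phase ends.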
  • Hi, thanks a lot. Yes, ofc I did, but I couldn't get it defined; tbh, it always returned an error. So I figured it could be easier to compute a confusion matrix myself and then plot it in wandb. – nothingisundercontrol Aug 11 '22 at 15:47
  • What error did it return? It might be related to the data type you are passing to confusion_matrix. Try passing it NumPy arrays, or even just lists. – morganmcg Aug 12 '22 at 14:50
  • Various ones; most of them were "not defined" errors, so that's definitely my fault. I thought about lists too, but I want to compute the matrix without lists. I'm having a really hard time defining it. :) – nothingisundercontrol Aug 17 '22 at 12:31
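If the counts are already computed (like the `confusion_matrix` tensor in the question), one way to still use the built-in chart is to turn the counts back into per-sample labels, because `wandb.plot.confusion_matrix` takes `y_true`/`preds` sequences rather than a count matrix. A sketch, assuming rows are true classes and columns are predicted classes: cell (t, p) with count c becomes c copies of the pair (t, p). `counts_to_pairs` is a hypothetical helper; `int(count)` is there because `torch.zeros` produces float counts.

```python
from typing import List, Sequence, Tuple


def counts_to_pairs(cm: Sequence[Sequence]) -> Tuple[List[int], List[int]]:
    """Expand a count matrix (rows: true class, cols: predicted) into label lists."""
    y_true: List[int] = []
    preds: List[int] = []
    for t, row in enumerate(cm):
        for p, count in enumerate(row):
            n = int(count)          # float tensor/array elements become plain ints
            y_true.extend([t] * n)  # n samples whose true class is t ...
            preds.extend([p] * n)   # ... that were predicted as class p
    return y_true, preds


cm = [[2, 0, 0], [0, 2, 1], [0, 0, 1]]
y_true, preds = counts_to_pairs(cm)
print(y_true)  # [0, 0, 1, 1, 1, 2]
print(preds)   # [0, 0, 1, 1, 2, 2]
```

The two lists can then be passed as `y_true=` and `preds=`; a matrix rebuilt from them has exactly the original counts.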