1

I am training AlexNet on the CIFAR-10 dataset using PyTorch Lightning. Here is my model:

class SelfSupervisedModel(pl.LightningModule):
    """AlexNet classifier for CIFAR-10 wrapped in a LightningModule.

    After each validation epoch, the predicted class indices and the true
    targets of every validation batch are concatenated and stored on
    ``OutputMatrix.predictions`` / ``OutputMatrix.targets`` so a confusion
    matrix can be built from them.

    Args:
        hparams: Not read by this class; kept so existing callers still work.
        num_classes: Number of outputs of AlexNet's classifier head.
        batch_size: Stored as ``self.batch_size``; not used inside this class
            (presumably read by data-loading code elsewhere — confirm).
    """

    def __init__(self, hparams=None, num_classes=10, batch_size=128):
        super(SelfSupervisedModel, self).__init__()

        self.batch_size = batch_size
        self.loss_fn = nn.CrossEntropyLoss()
        # NOTE(review): the learning rate comes from a project-level config
        # (ModelHelper), not from the ``hparams`` argument.
        self.hparams["lr"] = ModelHelper.Hyperparam.Learning_rate

        # torchvision's AlexNet defaults to 1000 (ImageNet) outputs. Pass
        # num_classes so the head has one logit per CIFAR-10 class and argmax
        # can only produce valid label indices.
        self.model = torchvision.models.alexnet(pretrained=False, num_classes=num_classes)

    def forward(self, x):
        """Return the AlexNet logits for the input batch ``x``."""
        return self.model(x)

    def training_step(self, train_batch, batch_idx):
        """Compute the cross-entropy loss for one ``(inputs, targets)`` batch."""
        inputs, targets = train_batch
        predictions = self(inputs)
        loss = self.loss_fn(predictions, targets)
        return {'loss': loss}

    def validation_step(self, test_batch, batch_idx):
        """Compute loss and accuracy for one validation batch.

        Returns a dict with the batch loss, the batch accuracy, the targets,
        and the *predicted class indices* (not the raw logits) under 'preds'.
        """
        inputs, targets = test_batch
        logits = self(inputs)
        val_loss = self.loss_fn(logits, targets)
        # `tf` appears to be an alias for `torch` in this file.
        # Predicted label = index of the largest logit along the class axis.
        _, preds = tf.max(logits, 1)
        acc = tf.sum(preds == targets.data) / (targets.shape[0] * 1.0)
        # Storing `preds` (labels) instead of `logits` is what makes the
        # confusion matrix built from OutputMatrix.predictions meaningful.
        return {'val_loss': val_loss, 'val_acc': acc, 'target': targets, 'preds': preds}

    def validation_epoch_end(self, outputs):
        """Average the per-batch metrics and store predictions and targets.

        ``outputs`` is the list of dicts returned by ``validation_step``.
        The values stored on ``OutputMatrix`` are overwritten every
        validation epoch, so they reflect the most recent one.
        """
        avg_loss = tf.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = tf.stack([x['val_acc'].float() for x in outputs]).mean()
        logs = {'val_loss': avg_loss, 'val_acc': avg_acc}
        print(f'validation_epoch_end logs => {logs}')

        OutputMatrix.predictions = tf.cat([tmp['preds'] for tmp in outputs])
        OutputMatrix.targets = tf.cat([tmp['target'] for tmp in outputs])

        # NOTE(review): returning {'progress_bar': ...} is the old (pre-1.0)
        # Lightning API; newer versions expect self.log(...) instead.
        return {'progress_bar': logs}

    def configure_optimizers(self):
        """Return SGD with momentum 0.9 over all model parameters."""
        return tf.optim.SGD(self.parameters(), lr=self.hparams["lr"], momentum=0.9)

I am storing the predicted and true values in `OutputMatrix.predictions` and `OutputMatrix.targets`, which are used to generate a confusion matrix that looks like this: [confusion matrix image not included]

I'm fairly sure this should not be the output, but I cannot find where the mistake is. Any help would be appreciated.

nim_10
  • 477
  • 8
  • 21
  • Here is a solution using Torchvision's metrics: https://stackoverflow.com/a/65628129/3731282 – nim_10 Oct 23 '21 at 23:13

1 Answer

0

I would suggest using TorchMetrics together with Lightning's built-in `self.log` method, so the code could look like this:

class MyModule(LightningModule):
    """Sketch of a LightningModule that tracks accuracy with TorchMetrics.

    ``...`` marks code elided from the example (model layers, loss, etc.).
    """

    def __init__(self):
        # Must run before assigning any nn.Module attribute (metrics are modules).
        super().__init__()
        ...
        # One metric object per stage so training and validation state never mix.
        # NOTE(review): torchmetrics >= 0.11 requires a task, e.g.
        # torchmetrics.Accuracy(task="multiclass", num_classes=10).
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        """Update and log step-level training accuracy for one batch."""
        x, y = batch
        preds = self(x)
        ...
        # Update the metric's internal state with this batch, then log the
        # metric object itself so Lightning computes and resets it.
        self.train_acc(preds, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        """Update and log validation accuracy per step and per epoch."""
        # Unpack the batch; without this, x and y are undefined below.
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

as shown in the TorchMetrics documentation on PyTorch Lightning integration.

Jirka
  • 1,126
  • 6
  • 25