Could someone explain where I can find out what the values in the output of PyTorch Lightning's trainer.predict() mean?

I have this code:

#Predicting
path = analysis.best_checkpoint + '/' + "ray_ckpt"

model = GraphLevelGNN.load_from_checkpoint(path)
model.eval() 

trainer = pl.Trainer()
test_result = trainer.test(model, graph_test_loader, verbose=False)

print(test_result)
##[{'test_acc': 0.65625, 'test_f1': 0.7678904428904428, 'test_precision': 1.0, 'test_recall': 0.65625}]

predictions = trainer.predict(model, graph_test_loader)

print(predictions)

And it prints:

[(tensor(0.7582), tensor(0.5000), 0.6666666666666666, 1.0, 0.5), (tensor(0.4276), tensor(0.7500), 0.8571428571428571, 1.0, 0.75), (tensor(0.4436), tensor(0.7500), 0.8571428571428571, 1.0, 0.75), (tensor(0.2545), tensor(1.), 1.0, 1.0, 1.0), (tensor(1.0004), tensor(0.3750), 0.5454545454545454, 1.0, 0.375)]

But I can't work out what these numbers mean. Can someone explain how to get more information about them?

Slowat_Kela

1 Answer

In short, it's the output of the forward pass, which you can customize by defining a prediction step:

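import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Placeholder network for illustration; substitute your own model
        self.base_model = torch.nn.Linear(4, 2)

    def forward(self, inputs):
        return self.base_model(inputs)

    # Override the predict step: whatever this returns is what
    # trainer.predict() collects for each batch
    def predict_step(self, batch, batch_idx):
        return self(batch)

model = LitModel()
trainer = pl.Trainer()
data = DataLoader(torch.randn(8, 4), batch_size=4)  # placeholder dataloader
predictions = trainer.predict(model, data)  # list with one entry per batch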

For a deeper explanation, see the question "output prediction of pytorch lightning model".
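As a rough check on the output in the question, here is a minimal sketch that labels each per-batch tuple and tests one guess about its layout. The field names are an assumption inferred from the metric keys printed by trainer.test (test_acc, test_f1, test_precision, test_recall); "loss" for the first value is only a guess, and the real order is whatever GraphLevelGNN's forward or predict_step returns. The tensors are written as plain floats so the snippet runs without PyTorch.

```python
import math

# Per-batch tuples copied from the question's output (tensors written as floats).
predictions = [
    (0.7582, 0.5000, 0.6666666666666666, 1.0, 0.5),
    (0.4276, 0.7500, 0.8571428571428571, 1.0, 0.75),
    (0.4436, 0.7500, 0.8571428571428571, 1.0, 0.75),
    (0.2545, 1.0, 1.0, 1.0, 1.0),
    (1.0004, 0.3750, 0.5454545454545454, 1.0, 0.375),
]

# ASSUMPTION: this order is guessed from the test_result keys; "loss" is a guess.
FIELDS = ("loss", "acc", "f1", "precision", "recall")


def label(batch_tuple):
    """Attach the assumed field names to one per-batch tuple."""
    return dict(zip(FIELDS, batch_tuple))


def f1_from(precision, recall):
    """Standard F1 score: harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


labeled = [label(p) for p in predictions]

# If the guess is right, the third value of every tuple should equal the F1
# computed from the fourth and fifth values.
consistent = all(
    math.isclose(row["f1"], f1_from(row["precision"], row["recall"]), rel_tol=1e-9)
    for row in labeled
)
print(consistent)
```

If the check passes, the last three values behave like F1, precision and recall for each batch, but the model's own code remains the only authoritative source for what each position means.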

Edwin Cheong
  • Thanks, but that link doesn't explain what the fields in the output are, does it? And the link in the PyTorch Lightning docs that might explain it is broken: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#testing – Slowat_Kela Aug 23 '22 at 12:32
  • No, because it's the output of your model. Check your model's forward function and what it returns; it differs case by case. – Edwin Cheong Aug 23 '22 at 12:34