9

This is potentially a very easy question. I just started with PyTorch Lightning and can't figure out how to get the output of my model after training.

I am interested in the predictions for both y_train and y_test as an array of some sort (a PyTorch tensor, or a NumPy array in a later step) so that I can plot them next to the labels using different scripts.

dataset = Dataset(train_tensor)
val_dataset = Dataset(val_tensor)
training_generator = torch.utils.data.DataLoader(dataset, **train_params)
val_generator = torch.utils.data.DataLoader(val_dataset, **val_params)
mynet = Net(feature_len)
trainer = pl.Trainer(gpus=0, max_epochs=max_epochs, logger=logger, progress_bar_refresh_rate=20, callbacks=[early_stop_callback], num_sanity_val_steps=0)
trainer.fit(mynet)

In my lightning module I have the functions:

def __init__(self, random_inputs):

def forward(self, x):

def train_dataloader(self):
    
def val_dataloader(self):

def training_step(self, batch, batch_nb):

def training_epoch_end(self, outputs):

def validation_step(self, batch, batch_nb):

def validation_epoch_end(self, outputs):

def configure_optimizers(self):

Do I need a specific predict function or is there any already implemented way I don't see?

Tom S

4 Answers

10

I disagree with these answers: the OP's question appears to be about how to use a model trained in Lightning to get predictions in general, rather than at a specific step in the training pipeline. In that case, a user shouldn't need to go anywhere near a Trainer object. Trainers are not intended for general prediction, so the other answers encourage an anti-pattern (carrying a Trainer object around every time we want to do some prediction) for anyone who reads them in the future.

Instead of using a Trainer, we can get predictions straight from the Lightning module that has been defined: given a trained instance of the Lightning module, model = Net(...), predictions on inputs x are obtained simply by calling model(x), so long as the forward method has been implemented/overridden on the Lightning module (which is required).

In contrast, Trainer.predict() is not the intended means of obtaining predictions using your trained model in general. The Trainer API provides methods to tune, fit and test your LightningModule as part of your training pipeline, and it looks to me that the predict method is provided for ad-hoc predictions on separate dataloaders as part of less 'standard' training steps.

The OP's question (Do I need a specific predict function or is there any already implemented way I don't see?) implies that they're not familiar with the way that the forward() method works in PyTorch, but asks whether there's already a method for prediction that they can't see. A full answer therefore requires a further explanation of where the forward() method fits into the prediction process:

model(x) works because Lightning modules are subclasses of torch.nn.Module, which implements a magic method called __call__(); this lets us call the class instance as if it were a function. __call__() in turn calls forward(), which is why we need to override that method in our Lightning module.

NB. because forward is only one piece of the logic called when we use model(x), it is always recommended to use model(x) instead of model.forward(x) for prediction unless you have a specific reason to deviate.
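To make the `__call__()` / `forward()` relationship concrete, here is a minimal pure-Python sketch. `TinyModule` and `Doubler` are hypothetical stand-ins invented for illustration; they are not the real `torch.nn.Module`, whose `__call__` does considerably more (and passes hook inputs as a tuple), but the dispatch idea is the same.

```python
class TinyModule:
    """Simplified, hypothetical stand-in for torch.nn.Module (illustration only)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks are extra logic that runs around forward().
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError("subclasses must override forward()")

    def __call__(self, x):
        # Calling the instance dispatches to forward(), then runs the hooks.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


seen = []
model = Doubler()
model.register_forward_hook(lambda module, inp, out: seen.append(out))

via_call = model(3)             # runs forward() and then the hook
via_forward = model.forward(3)  # runs forward() only; the hook is skipped
```

Both calls return the same value here, but only `model(3)` triggers the registered hook, which is the practical reason to prefer `model(x)` over `model.forward(x)`.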

ericmjl
UpstatePedro
  • It's good that you pointed out how a network can be run directly, since starting with PyTorch Lightning without ever having used PyTorch directly hides the underlying mechanisms. I would argue that it's still reasonable in some situations to use the Trainer class even for prediction: it handles putting your model and data onto the GPU and it can call certain hooks, so why reinvent the wheel? It's not an antipattern; rename the class to `Commander` and much of your argument is invalid. I still think it's good you pointed it out, but antipattern is too strong. – Florian Blume Jan 21 '22 at 14:18
  • I think advice on how to get predictions from the model needs to include how to run it on a gpu, model.eval(), turning off gradients and all the other things that Lightning has done for the user so far. Simply calling model(x) is unlikely to do what the user wants. – Thomas Ahle Apr 07 '22 at 20:04
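Picking up on the points in these comments, here is a hedged sketch of what "calling model(x)" might look like once eval mode, disabled gradients and device placement are handled by hand. `predict_to_numpy` is a hypothetical helper name, the toy `net` stands in for a trained LightningModule (which is also an `nn.Module`), and the code assumes every batch from the loader is the input tensor itself.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader


def predict_to_numpy(model: nn.Module, loader: DataLoader, device: str = "cpu") -> np.ndarray:
    """Run the model over every batch and return one NumPy array (hypothetical helper)."""
    model = model.to(device)
    was_training = model.training
    model.eval()  # disable dropout, use running batch-norm statistics
    outputs = []
    with torch.no_grad():  # no autograd graph is built during inference
        for batch in loader:
            # Assumption: each batch is the input tensor (no labels attached).
            outputs.append(model(batch.to(device)).cpu())
    model.train(was_training)  # restore the mode the model was in before
    return torch.cat(outputs).numpy()


torch.manual_seed(0)
# Toy network standing in for a trained model.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 1))
inputs = torch.randn(10, 4)
loader = DataLoader(inputs, batch_size=4, shuffle=False)  # keep order aligned with labels

y_pred = predict_to_numpy(net, loader)
```

The same helper could be called once with a loader over the training inputs and once with a loader over the test inputs to get both sets of predictions as arrays for plotting.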
4

You can use the predict method as well. Here is the example from the documentation: https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html

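from typing import Optional

from pytorch_lightning import LightningModule


class LitMNISTDreamer(LightningModule):

    def forward(self, z):
        # self.decoder is not shown in this excerpt.
        imgs = self.decoder(z)
        return imgs

    def predict_step(self, batch, batch_idx: int, dataloader_idx: Optional[int] = None):
        # Called by trainer.predict() for each batch; delegates to forward().
        return self(batch)


model = LitMNISTDreamer()
# Assumes `trainer` (a Trainer) and `datamodule` (a LightningDataModule) already exist.
trainer.predict(model, datamodule=datamodule)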
sushmit
  • The predict method seems to have been added in the meantime. I was just baffled it wasn't available before. – Tom S Jun 22 '21 at 15:23
  • Yeah they seem crazy good at adding new stuff – Adrien Forbu Jul 25 '21 at 12:59
  • What's the difference between using `trainer.predict()` and using `model()`? Does the first option automatically wrap the call inside eval mode and no_grad? – Michael Aug 02 '21 at 11:01
  • The trainer puts your model and input on the graphics card, limits the number of batches (if set, see trainer `__init__` args), performs distributed computation and so on. – Florian Blume Jan 21 '22 at 14:13
  • Is there any way to run "predict" as an iterator? I don't really want to load all of my data into memory – Thomas Ahle Apr 07 '22 at 20:05
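On the iterator question, one possible approach that sidesteps the Trainer entirely is a plain Python generator over the DataLoader, sketched below. `iter_predictions` is a hypothetical name, the `nn.Linear` is a stand-in for a trained model, and batches are assumed to be input tensors. Note that the `yield` sits outside the `torch.no_grad()` block so the disabled-gradient state does not leak into the caller's code between batches.

```python
from typing import Iterator

import torch
from torch import nn
from torch.utils.data import DataLoader


def iter_predictions(model: nn.Module, loader: DataLoader) -> Iterator[torch.Tensor]:
    """Yield predictions one batch at a time (hypothetical helper)."""
    model.eval()
    for batch in loader:
        with torch.no_grad():
            out = model(batch)  # assumption: batch is the input tensor
        # Yield outside the no_grad block so grad mode is restored for the caller.
        yield out


net = nn.Linear(3, 2)  # stand-in for a trained model
loader = DataLoader(torch.randn(7, 3), batch_size=3)

# Only one batch of predictions is held at a time while iterating.
batch_shapes = [tuple(pred.shape) for pred in iter_predictions(net, loader)]
```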
4

You can try prediction in two ways:

  1. Perform batched prediction as per normal.
test_dataset = Dataset(test_tensor)
test_generator = torch.utils.data.DataLoader(test_dataset, **test_params)

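mynet.eval()  # disable dropout and use running batch-norm statistics
batch = next(iter(test_generator))
with torch.no_grad():  # skip gradient tracking during inference
    # Assumes each batch is the input tensor itself; unpack it first if your Dataset yields (x, y) pairs.
    predictions_single_batch = mynet(batch)
  2. Instantiate a new Trainer object. Trainer's predict API allows you to pass an arbitrary DataLoader.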
test_dataset = Dataset(test_tensor)
test_generator = torch.utils.data.DataLoader(test_dataset, **test_params)

predictor = pl.Trainer(gpus=1)
predictions_all_batches = predictor.predict(mynet, dataloaders=test_generator)

I've noticed that in the second case, PyTorch Lightning takes care of things like moving your tensors and model onto (not off of) the GPU, in line with its ability to perform distributed predictions. It also doesn't return any gradient-attached loss values, which dispenses with the need to write boilerplate code like with torch.no_grad().
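For a single DataLoader, `predict` returns a list with one entry per batch, so a final step is usually needed to get one array. The sketch below uses a hand-made list as a stand-in for `predictions_all_batches` (so it runs without a Trainer) and assumes `predict_step` returns a plain tensor for each batch.

```python
import torch

# Stand-in for the list returned by predictor.predict(...): one tensor per batch.
predictions_all_batches = [torch.zeros(4, 1), torch.ones(4, 1), torch.ones(2, 1)]

# Stitch the per-batch tensors back together along the batch dimension.
predictions = torch.cat(predictions_all_batches, dim=0)

# Convert to NumPy for plotting next to the labels.
predictions_np = predictions.cpu().numpy()
```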

user11717481
Ying Jiang
  • An important point about this answer is that in some circumstances you need to create a new trainer for testing/prediction. The [documentation for predict](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#predict) explains that accelerators that spawn new processes won't return predictions (so they will not sync if you want to gather them later), e.g. under DDP. So you can train under DDP but cannot do inference under DDP, as it's not supported. – davzaman Aug 26 '22 at 22:27
  • I haven't tested this, but my understanding of the statement `True by default except when an accelerator that spawns processes is used (not supported).` is that `return_predictions` is not supported if we set up the `Trainer` with `ddp_spawn` instead of `ddp`. There might be some complications or bottlenecks with `mp.spawn()`. I do agree that setting up a predictor with `Trainer` is semantically quite confusing. – Ying Jiang May 22 '23 at 17:39
0

The trainer has a test function. You might want to have a look at the official pytorch-lightning documentation for more details: https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#testing.

SJ11
  • Awesome. How did I not find this myself? Most likely because of all the errors I got. But I sorted everything out. – Tom S Jan 20 '21 at 15:41
  • Seems to have a predict function now: https://github.com/PyTorchLightning/pytorch-lightning/issues/1853 – Georg Heiler May 20 '21 at 19:17
  • I don't believe `.test` allows you to return a tensor (its purpose is largely to collect logs via the `logging` API, which doesn't currently accept lists or torch/np arrays). So `.predict()` appears to be the way forward. – hkh Jan 27 '22 at 09:53