
I have a Ray Tune analysis object and can get the best checkpoint from it:

analysis = tune_robert_asha(num_samples=2)
best_ckpt = analysis.best_checkpoint

But I cannot restore my PyTorch Lightning model from it.

I tried:

MyLightningModel.load_from_checkpoint(
    os.path.join(analysis.best_checkpoint, "checkpoint")
)

However, MyLightningModel takes a config argument in its constructor so that Ray Tune can set hyperparameters for each trial:

import torch.nn as nn
import pytorch_lightning as pl

class MyLightningModel(pl.LightningModule):
    def __init__(self, config=None):
        super().__init__()
        # hyperparameters chosen by Ray Tune for this trial
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.layer_size = config["layer_size"]

        self.lstm = nn.LSTM(768, self.layer_size, num_layers=1, bidirectional=False)
        self.out = nn.Linear(self.layer_size, 1)

So when I call load_from_checkpoint, the constructor of MyLightningModel fails because config is None:


TypeError                                 Traceback (most recent call last)
in <module>()
      1 MyLightningModel.load_from_checkpoint(
----> 2     os.path.join(analysis.best_checkpoint, "checkpoint")
      3 )

2 frames
in __init__(self, config)
      3     def __init__(self, config=None):
      4
----> 5         self.lr = config["lr"]
      6         self.batch_size = config["batch_size"]
      7         self.layer_size = config["layer_size"]

TypeError: 'NoneType' object is not subscriptable
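To see why config ends up as None, here is a simplified, hypothetical stand-in for the loading step (it is not Lightning's actual code). It models only the relevant behaviour: the class is re-created from whatever init arguments were stored in the checkpoint, updated with any keyword arguments passed to the call. MyModelSketch and load_from_checkpoint_sketch are made-up names for illustration. Since this model never stored its config, the saved arguments are empty and the constructor runs with its default config=None.

```python
class MyModelSketch:
    """Hypothetical stand-in for MyLightningModel (no torch / Lightning needed)."""

    def __init__(self, config=None):
        # Same pattern as the question: indexing a None config raises TypeError.
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.layer_size = config["layer_size"]


def load_from_checkpoint_sketch(cls, checkpoint, **kwargs):
    # Start from the init arguments saved in the checkpoint (none here) ...
    init_args = dict(checkpoint.get("hyper_parameters", {}))
    # ... then apply keyword arguments passed by the caller.
    init_args.update(kwargs)
    return cls(**init_args)


# A checkpoint holding only weights: no hyperparameters were saved.
checkpoint = {"state_dict": {}}

try:
    load_from_checkpoint_sketch(MyModelSketch, checkpoint)
except TypeError as err:
    print(err)  # 'NoneType' object is not subscriptable
```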

How is this supposed to be handled?

Luca Guarro
  • I also asked this question on the ray tune forum site but since it can be hard to get responses there, I wanted to ask here as well – Luca Guarro Oct 21 '21 at 16:34
  • This is being discussed here https://discuss.ray.io/t/how-to-properly-restore-checkpoint-when-using-pytorch-lightning/3895 – Amog Kamsetty Oct 21 '21 at 17:02

1 Answer


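You need to override the default value of config, which in your case is None. load_from_checkpoint passes any extra keyword arguments on to the model's __init__ (they also override saved hyperparameters), so give it the config dict of the trial you are restoring:

import os

# the constructor takes a config argument
model = MyLightningModel(config)

# uses the default argument for config, which is None -> TypeError
model = MyLightningModel.load_from_checkpoint(PATH)

# override the default by passing a config dict (not a path); with Ray Tune,
# analysis.best_config gives the best trial's config, assuming metric and
# mode were passed to tune.run()
model = MyLightningModel.load_from_checkpoint(
    os.path.join(analysis.best_checkpoint, "checkpoint"),
    config=analysis.best_config,
)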

Source: https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html#initialize-with-other-parameters
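Passing config works because extra keyword arguments reach __init__. Lightning also offers a way to avoid passing it at all: calling self.save_hyperparameters() in __init__ (after super().__init__()) stores the constructor arguments in the checkpoint, and load_from_checkpoint reuses them, with any keyword arguments you pass taking precedence. The sketch below models that behaviour in plain Python under that assumption; MyModelSketch, save_checkpoint_sketch and load_from_checkpoint_sketch are hypothetical stand-ins, not Lightning APIs.

```python
import copy


class MyModelSketch:
    """Hypothetical stand-in for MyLightningModel (no torch / Lightning needed)."""

    def __init__(self, config=None):
        # Record the init arguments, roughly what save_hyperparameters() keeps.
        self.hparams = {"config": copy.deepcopy(config)}
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.layer_size = config["layer_size"]


def save_checkpoint_sketch(model):
    # Hypothetical: store the recorded init arguments next to the weights.
    return {"state_dict": {}, "hyper_parameters": copy.deepcopy(model.hparams)}


def load_from_checkpoint_sketch(cls, checkpoint, **kwargs):
    # Saved init arguments first, then caller-supplied keyword overrides.
    init_args = dict(checkpoint.get("hyper_parameters", {}))
    init_args.update(kwargs)
    return cls(**init_args)


trial_config = {"lr": 1e-3, "batch_size": 32, "layer_size": 128}
ckpt = save_checkpoint_sketch(MyModelSketch(trial_config))

# The config was stored, so no keyword argument is needed.
restored = load_from_checkpoint_sketch(MyModelSketch, ckpt)
print(restored.lr, restored.layer_size)  # 0.001 128

# An explicit config still wins over the stored one.
override = dict(trial_config, lr=5e-4)
overridden = load_from_checkpoint_sketch(MyModelSketch, ckpt, config=override)
print(overridden.lr)  # 0.0005
```

In the real model this amounts to adding self.save_hyperparameters() right after super().__init__() in MyLightningModel.__init__; checkpoints written by trials run after that change should then load with MyLightningModel.load_from_checkpoint(path) alone.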

Mike B