What is a good practice to avoid losing hours/days of training a network if something breaks in the middle?
2 Answers
I use a custom callback that stores the last epoch, weights, loss, etc. so that training can be resumed afterwards:
import json
import socket

from keras.callbacks import ModelCheckpoint


class StatefulCheckpoint(ModelCheckpoint):
    """Save extra checkpoint data to resume training."""

    def __init__(self, weight_file, state_file=None, **kwargs):
        """Save the state (epoch etc.) alongside the weights."""
        super().__init__(weight_file, **kwargs)
        self.state_f = state_file
        self.state = dict()
        self.hostname = socket.gethostname()  # record which machine wrote the state
        if self.state_f:
            # Load the last state if any
            try:
                with open(self.state_f, 'r') as f:
                    self.state = json.load(f)
                self.best = self.state['best']
            except Exception as e:  # pylint: disable=broad-except
                print("Skipping last state:", e)

    def on_epoch_end(self, epoch, logs=None):
        """Saves training state as well as weights."""
        super().on_epoch_end(epoch, logs)
        if self.state_f:
            state = {'epoch': epoch + 1, 'best': self.best,
                     'hostname': self.hostname}
            state.update(logs)
            state.update(self.params)
            with open(self.state_f, 'w') as f:
                json.dump(state, f)

    def get_last_epoch(self, initial_epoch=0):
        """Return last saved epoch if any, or return default argument."""
        return self.state.get('epoch', initial_epoch)
This only works well if your epochs are of reasonable length (e.g. around an hour), but it is clean and consistent with the Keras API.
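A resumed run would then look roughly like this (a minimal sketch: model, x_train, y_train and the file names are placeholders, and any other ModelCheckpoint keyword arguments can be passed through):

ckpt = StatefulCheckpoint('weights.h5', state_file='state.json',
                          save_best_only=True, save_weights_only=True)

# Picks up from the last recorded epoch if a state file was found,
# otherwise starts from the default initial epoch 0.
model.fit(x_train, y_train,
          epochs=100,
          initial_epoch=ckpt.get_last_epoch(),
          callbacks=[ckpt])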

nuric
A simple solution is to use logging and to serialize your models to disk at regular intervals. You could keep up to, say, 5 versions of the network to avoid running out of disk space.
Python has great logging utilities, and you might find pickle useful to serialize your models. A sketch of that idea is given below.
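A rough sketch (the names save_snapshot, CHECKPOINT_DIR and MAX_VERSIONS are made up here, and the call interval is up to you; pickle works for plain Python objects, while Keras models are usually saved with their own save method):

import logging
import pickle
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

CHECKPOINT_DIR = Path('checkpoints')   # hypothetical location
MAX_VERSIONS = 5                       # keep at most 5 snapshots on disk

def save_snapshot(model, step):
    """Pickle the model and prune old snapshots beyond MAX_VERSIONS."""
    CHECKPOINT_DIR.mkdir(exist_ok=True)
    path = CHECKPOINT_DIR / f'model_{step:06d}.pkl'
    with open(path, 'wb') as f:
        pickle.dump(model, f)
    logger.info("Saved checkpoint %s", path)

    # Remove the oldest snapshots so only MAX_VERSIONS remain.
    snapshots = sorted(CHECKPOINT_DIR.glob('model_*.pkl'))
    for old in snapshots[:-MAX_VERSIONS]:
        old.unlink()
        logger.info("Pruned old checkpoint %s", old)

You would then call save_snapshot(model, step) every N batches or epochs from your training loop, and load the newest surviving file to resume after a crash.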

shayaan