In each epoch, I:
- prune the model
- optimize the model
- remove the pruning mask and save the model

However, I want to keep exactly the same mask, i.e. zero out the same weights at the beginning of every epoch.
My current code:
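```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Start of epoch: prune 30% of each Conv2d's weights (smallest L1 magnitude)
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, 'weight', 0.3)
# ... optimizing ...
# End of epoch: make pruning permanent (this deletes the mask)
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.remove(module, 'weight')
# ... saving ...
```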
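To see where the mask actually lives, here is a minimal sketch on a single randomly initialized `Conv2d` (the layer sizes are arbitrary assumptions). While pruning is active, `l1_unstructured` keeps the trainable tensor as `weight_orig` and the 0/1 mask as a buffer named `weight_mask`; `prune.remove` bakes the zeros into `weight` and deletes that buffer, so the mask has to be copied out before `remove` is called.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)  # arbitrary example layer

# Zero out the 30% of weights with the smallest absolute value.
prune.l1_unstructured(conv, name="weight", amount=0.3)

# While pruning is active, the module holds:
#   weight_orig : the trainable parameter
#   weight_mask : a 0/1 buffer (the mask to keep)
#   weight      : recomputed as weight_orig * weight_mask before each forward
param_names = sorted(name for name, _ in conv.named_parameters())
buffer_names = [name for name, _ in conv.named_buffers()]
saved_mask = conv.weight_mask.detach().clone()  # copy it before remove()
print(param_names, buffer_names, int((saved_mask == 0).sum()))

# remove() writes weight_orig * weight_mask into `weight` and deletes the mask buffer.
prune.remove(conv, "weight")
print(hasattr(conv, "weight_mask"))
```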
Is there a way to save the mask and reapply it? Then I could prune once before training, save the mask, and reapply it at the beginning of each epoch. Something like this:
```python
# Wished-for API: torch.nn.utils.prune has no get_masks / apply_mask functions
mask = torch.nn.utils.prune.get_masks(model)
torch.save(mask, 'mask.pt')
# ...
mask = torch.load('mask.pt')
torch.nn.utils.prune.apply_mask(model, mask)
```
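One possible way to get this behaviour with functions that do exist in `torch.nn.utils.prune` is sketched below, assuming only `Conv2d` layers are pruned and masks are keyed by the names from `model.named_modules()`. The helpers `collect_conv_masks`, `apply_conv_masks` and `remove_conv_pruning` are hypothetical names, not part of PyTorch; the reapplication step relies on `prune.custom_from_mask`, which rebuilds the pruning re-parametrization from a fixed mask instead of recomputing one from the current weight magnitudes. The toy model, the temp-file path and the single random-data SGD step standing in for "optimizing" are all illustrative assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def collect_conv_masks(model: nn.Module) -> dict:
    # Hypothetical helper: copy each pruned Conv2d's mask, keyed by module name.
    # Must run before prune.remove(), which deletes the `weight_mask` buffer.
    return {
        name: module.weight_mask.detach().clone()
        for name, module in model.named_modules()
        if isinstance(module, nn.Conv2d) and hasattr(module, "weight_mask")
    }


def apply_conv_masks(model: nn.Module, masks: dict) -> None:
    # Hypothetical helper: re-create pruning from the saved, fixed masks.
    modules = dict(model.named_modules())
    for name, mask in masks.items():
        prune.custom_from_mask(modules[name], name="weight", mask=mask)


def remove_conv_pruning(model: nn.Module) -> None:
    # Hypothetical helper: make pruning permanent (weight = weight_orig * mask).
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and hasattr(module, "weight_mask"):
            prune.remove(module, "weight")


torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
mask_path = os.path.join(tempfile.gettempdir(), "mask.pt")

# Prune once, before training, and store the resulting masks.
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, "weight", amount=0.3)
torch.save(collect_conv_masks(model), mask_path)
remove_conv_pruning(model)

for epoch in range(2):
    masks = torch.load(mask_path)  # pass map_location=... when training on GPU
    apply_conv_masks(model, masks)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    # Stand-in for "... optimizing ...": one step on random data.
    loss = model(torch.randn(2, 3, 8, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    remove_conv_pruning(model)
    # ... saving ...
```

While the mask is applied, `weight` is computed as `weight_orig * weight_mask`, so the masked entries receive zero gradient through that product. After `remove`, the same positions are zero in the saved weights every epoch. Note that `custom_from_mask` multiplies the stored mask with a default mask created on the weight's device, so the loaded mask should be on the same device as the model.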