
In each epoch I:

  1. prune the model
  2. optimize the model
  3. remove the pruning mask and save the model

However, I want to keep exactly the same mask, i.e. zero out the same weights at the beginning of each epoch.

My current code:

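import torch
import torch.nn as nn
import torch.nn.utils.prune  # ensures torch.nn.utils.prune is loaded

# 1. prune 30% of the weights (smallest absolute values) in every Conv2d layer
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        torch.nn.utils.prune.l1_unstructured(module, 'weight', 0.3)

# ... optimizing ...

# 3. make the pruning permanent (this discards weight_orig and weight_mask)
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        torch.nn.utils.prune.remove(module, 'weight')

# ... saving ...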
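For context, a minimal sketch of what the pruning calls above leave on a module, assuming a recent PyTorch and a single hypothetical `Conv2d` layer. Pruning stores the original values in the parameter `weight_orig` and the 0/1 mask in the buffer `weight_mask`, and `remove` bakes the zeros into a plain `weight` and discards the mask. That is why the mask is no longer available after step 3.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)  # hypothetical layer: 8*3*3*3 = 216 weights

prune.l1_unstructured(conv, "weight", amount=0.3)
# The layer is now reparametrized: original values live in the parameter
# `weight_orig`, the 0/1 mask in the buffer `weight_mask`, and `weight`
# is recomputed as weight_orig * weight_mask.
has_mask = hasattr(conv, "weight_mask")
has_orig = "weight_orig" in dict(conv.named_parameters())
mask_copy = conv.weight_mask.clone()  # keep a copy before it disappears

prune.remove(conv, "weight")
# After remove(), the zeros are baked into a plain `weight` parameter
# and the mask buffer no longer exists on the module.
mask_gone = not hasattr(conv, "weight_mask")
zeros_kept = bool((conv.weight[mask_copy == 0] == 0).all())

print(has_mask, has_orig, mask_gone, zeros_kept)  # expected: True True True True
```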

Is there a way to save the mask and reapply it? Then I would prune before training, save the mask, and just reapply it at the beginning of each epoch. Something like this:

mask = torch.nn.utils.prune.get_masks(model)
torch.save(mask, 'mask.pt')
...
mask = torch.load('mask.pt')
torch.nn.utils.prune.apply_mask(model, mask)
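One possible approach, sketched under assumptions: the helpers `get_conv_masks` and `apply_conv_masks` are hypothetical (they are not part of `torch.nn.utils.prune`), masks are keyed by the names from `model.named_modules()`, and the toy `nn.Sequential` model stands in for the real one. The sketch copies each layer's `weight_mask` buffer before `prune.remove`, saves the dict with `torch.save`, and reapplies it with `prune.custom_from_mask`, which zeroes exactly the stored positions instead of recomputing an L1 ranking.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def get_conv_masks(model: nn.Module) -> dict:
    """Hypothetical helper: copy the current pruning mask of every Conv2d,
    keyed by the module's name in model.named_modules()."""
    return {
        name: module.weight_mask.detach().clone()
        for name, module in model.named_modules()
        if isinstance(module, nn.Conv2d) and hasattr(module, "weight_mask")
    }


def apply_conv_masks(model: nn.Module, masks: dict) -> None:
    """Hypothetical helper: reapply saved masks with prune.custom_from_mask,
    which zeroes exactly the positions where the stored mask is 0."""
    for name, module in model.named_modules():
        if name in masks:
            prune.custom_from_mask(module, "weight", mask=masks[name])


# Toy stand-in for the real model (assumption).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))

# Prune once before training and store the resulting masks.
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, "weight", amount=0.3)
masks = get_conv_masks(model)
torch.save(masks, "mask.pt")
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.remove(module, "weight")

# At the start of each epoch: reload and reapply the same masks.
masks = torch.load("mask.pt")
apply_conv_masks(model, masks)
# ... optimizing ...
for module in model.modules():
    if isinstance(module, nn.Conv2d):
        prune.remove(module, "weight")
```

If the model lives on a GPU, loading with `torch.load('mask.pt', map_location=device)` keeps the masks on the same device as the weights they are multiplied with.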
kaycaborr
