
After a model is pruned in PyTorch, the saved model contains both the pruned weights and weight_orig. This makes the pruned model file larger than the unpruned one. Is there a way to remove weight_orig and reduce the size of the pruned model?
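A minimal sketch that reproduces this, assuming a toy torch.nn.Linear layer and 50% L1 unstructured pruning (both chosen only for illustration). After pruning, the layer's state_dict holds weight_orig plus a same-shaped weight_mask buffer instead of a single weight tensor, which is why the saved file grows:

```python
import torch
import torch.nn.utils.prune as prune

# Toy layer chosen only for illustration
layer = torch.nn.Linear(100, 100)
print(sorted(layer.state_dict().keys()))  # ['bias', 'weight']

# Zero out the 50% of weights with the smallest absolute value
prune.l1_unstructured(layer, name="weight", amount=0.5)

# The original weights are now the parameter 'weight_orig' and the binary
# mask is the buffer 'weight_mask'. 'weight' is a plain attribute that a
# forward pre-hook recomputes as weight_orig * weight_mask.
print(sorted(layer.state_dict().keys()))  # ['bias', 'weight_mask', 'weight_orig']
print(hasattr(layer, "weight"))  # True
```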

Hossein
AcidBurn


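As explained in the official documentation, you can use torch.nn.utils.prune.remove() for this very purpose.
remove() makes the pruning permanent: it writes the masked weights (weight_orig * weight_mask) back into the weight parameter, deletes the weight_orig parameter and the weight_mask buffer, and removes the forward_pre_hook that recomputed weight on every forward pass. You would use it like this: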

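import torch
import torch.nn.utils.prune as prune

# Make pruning permanent for every pruned Conv2d layer in the model
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d) and hasattr(module, "weight_orig"):
        prune.remove(module, "weight")
    # etc. for other pruned layer types (e.g. torch.nn.Linear)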
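To check the effect on file size, here is a minimal sketch on a single toy torch.nn.Linear layer. The layer size, the 50% L1 unstructured pruning and the helper saved_size are illustrative assumptions, not part of the answer. It serializes the state_dict to an in-memory buffer before pruning, after pruning, and after prune.remove():

```python
import io

import torch
import torch.nn.utils.prune as prune


def saved_size(module: torch.nn.Module) -> int:
    """Hypothetical helper: bytes torch.save writes for the module's state_dict."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return len(buffer.getvalue())


# Toy layer chosen only for illustration
layer = torch.nn.Linear(100, 100)
dense_size = saved_size(layer)

# Pruning stores weight_orig plus a same-shaped weight_mask buffer
prune.l1_unstructured(layer, name="weight", amount=0.5)
pruned_size = saved_size(layer)

# Fold the mask into 'weight' and drop weight_orig, weight_mask and the hook
prune.remove(layer, "weight")
removed_size = saved_size(layer)

print(sorted(layer.state_dict().keys()))  # ['bias', 'weight']
print(dense_size, pruned_size, removed_size)
```

After prune.remove() the file should drop back to roughly the unpruned size, not below it: the pruned entries are still stored as dense zeros in a tensor of the original shape. If you need a file smaller than the dense model, you would have to store the weights in a sparse format yourself (for example with Tensor.to_sparse()), and whether that saves space depends on how sparse the tensor is.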
Hossein