After a model is pruned in PyTorch, the saved model contains both the pruned weights and weight_orig. This makes the pruned model larger than the unpruned one. Is there a way to remove weight_orig and reduce the pruned model's size?
As explained in the official documentation, you can use torch.nn.utils.prune.remove() for this very purpose. remove() removes the re-parametrization in terms of weight_orig and weight_mask, and removes the forward_pre_hook. The pruning itself is kept: weight becomes an ordinary parameter again that holds the masked (zeroed) values.
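Here is a minimal sketch of what remove() does to a single layer, assuming a toy torch.nn.Linear(8, 4) pruned with prune.l1_unstructured (the layer size and the 30% amount are arbitrary choices for illustration). It records the parameter and buffer names and the number of zeroed weights before and after the call.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Toy layer (illustrative only): zero out 30% of its 32 weights by L1 magnitude.
layer = torch.nn.Linear(8, 4)
prune.l1_unstructured(layer, name="weight", amount=0.3)

# While pruned, "weight_orig" is the parameter, "weight_mask" is a buffer,
# and "weight" is recomputed from them by a forward pre-hook.
params_before = sorted(name for name, _ in layer.named_parameters())
buffers_before = sorted(name for name, _ in layer.named_buffers())
pruned_before = prune.is_pruned(layer)
zeros_before = int((layer.weight == 0).sum())

# Make the pruning permanent.
prune.remove(layer, "weight")

params_after = sorted(name for name, _ in layer.named_parameters())
buffers_after = sorted(name for name, _ in layer.named_buffers())
pruned_after = prune.is_pruned(layer)
zeros_after = int((layer.weight == 0).sum())

print(params_before, buffers_before, pruned_before)
print(params_after, buffers_after, pruned_after)
print(zeros_before, zeros_after)
```

The zero count is unchanged after remove(). Only the extra bookkeeping (weight_orig, weight_mask and the hook) goes away.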
You'd use it like this:
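import torch
import torch.nn.utils.prune as prune

for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        # Drop weight_orig, weight_mask and the pruning hook, keeping the zeros
        prune.remove(module, 'weight')
    # etc... (handle any other pruned layer types the same way)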
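To see the effect on the saved file, the following sketch compares the size of the serialized state_dict before pruning, after pruning and after remove(). The helpers make_model and saved_size and the 50% pruning amount are hypothetical, chosen only for illustration; the model is serialized to an in-memory buffer instead of a file.

```python
import io

import torch
import torch.nn.utils.prune as prune


def saved_size(model: torch.nn.Module) -> int:
    """Hypothetical helper: bytes torch.save writes for the model's state_dict."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return len(buffer.getvalue())


def make_model() -> torch.nn.Module:
    # Hypothetical small CNN, only for illustration.
    return torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=3),
        torch.nn.ReLU(),
        torch.nn.Conv2d(16, 32, kernel_size=3),
    )


torch.manual_seed(0)
model = make_model()
size_unpruned = saved_size(model)

# Prune half of the weights in every Conv2d layer.
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.5)
size_pruned = saved_size(model)  # state_dict now holds weight_orig and weight_mask

# Make the pruning permanent, as in the loop above.
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.remove(module, "weight")
size_removed = saved_size(model)  # back to one weight tensor per layer

keys = list(model.state_dict().keys())
print(size_unpruned, size_pruned, size_removed)
print(keys)
```

After remove() the file should be back to roughly the unpruned size, not below it, because the zeroed weights are still stored as dense tensors. Shrinking it further would need a sparse or compressed representation (for example Tensor.to_sparse()), which goes beyond what remove() does.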

Hossein