I am applying pruning with PyTorch's torch.nn.utils.prune to a model that contains LSTM layers. However, when I save the state_dict, the resulting file is much larger than before pruning. I'm not sure why: if I print the sizes of the elements of the state_dict before and after pruning, every tensor has the same dimensions, and there are no additional entries in the state_dict.
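For reference, here is the kind of check I ran to compare the two checkpoints saved in the snippet further down. The helper is just my own quick diagnostic (the function name and the byte tally are mine, not part of the pruning code); it only adds up what the state_dict nominally contains:

import torch

def summarize_checkpoint(path):
    # My own diagnostic helper: load a saved state_dict and report
    # each entry's shape plus the total nominal tensor size.
    state = torch.load(path, map_location='cpu')
    total_elems = 0
    total_bytes = 0
    for name, tensor in state.items():
        print(name, "\t", tuple(tensor.size()))
        total_elems += tensor.nelement()
        total_bytes += tensor.nelement() * tensor.element_size()
    print(path, ":", total_elems, "elements, ~", round(total_bytes / 1e6, 1), "MB of tensor data")

summarize_checkpoint('pre_pruning.pth')
summarize_checkpoint('pruned.pth')

Both files list the same entries with the same shapes, yet pruned.pth is far larger on disk.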
My code for pruning is pretty standard, and I make sure to call prune.remove():
import torch
import torch.nn.utils.prune as prune

# Save the state_dict and print the size of every entry before pruning
model_state = model.state_dict()
torch.save(model_state, 'pre_pruning.pth')
for param_tensor in model_state:
    print(param_tensor, "\t", model_state[param_tensor].size())

# Collect all weight parameters of the LSTM (model.rnn) for global pruning
parameters_to_prune = []
for param, _ in model.rnn.named_parameters():
    if "weight" in param:
        parameters_to_prune.append((model.rnn, param))

prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=0.6)

# Make the pruning permanent, removing the reparametrization and masks
for module, param in parameters_to_prune:
    prune.remove(module, param)

# Save the state_dict and print the size of every entry after pruning
model_state = model.state_dict()
torch.save(model_state, 'pruned.pth')  # This file is much larger than the original
for param_tensor in model_state:
    print(param_tensor, "\t", model_state[param_tensor].size())
When I prune the linear layers in the model instead, the saved model does not show the same increase in size as when I prune the LSTM layers. Any idea what could be causing this?
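For comparison, this is roughly what I do for a linear layer (model.fc below is just a placeholder name for whichever nn.Linear my model actually has, and pruned_linear.pth is an arbitrary file name):

# Same procedure applied to a linear layer for comparison;
# model.fc is a placeholder for the actual nn.Linear module.
prune.global_unstructured([(model.fc, "weight")],
                          pruning_method=prune.L1Unstructured,
                          amount=0.6)
prune.remove(model.fc, "weight")

torch.save(model.state_dict(), 'pruned_linear.pth')  # stays roughly the same size as pre_pruning.pth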