
I am applying pruning with PyTorch's torch.nn.utils.prune on a model with LSTM layers. However, when I save the state_dict, the resulting file is much larger than before pruning. I'm not sure why: if I print the sizes of the state_dict entries before and after pruning, every tensor has the same dimensions, and there are no additional entries in the state_dict.

My code for pruning is pretty standard, and I make sure to call prune.remove():

        model_state = model.state_dict()
        torch.save(model.state_dict(), 'pre_pruning.pth')
        for param_tensor in model_state:
            print(param_tensor, "\t", model_state[param_tensor].size())

        parameters_to_prune = []
        for param, _ in model.rnn.named_parameters():
            if "weight" in param:
                parameters_to_prune.append((model.rnn, param))
        prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.6)
        for module, param in parameters_to_prune:
            prune.remove(module, param)

        model_state = model.state_dict()
        torch.save(model_state, 'pruned.pth') # This file is much larger than the original
        for param_tensor in model_state:
            print(param_tensor, "\t", model_state[param_tensor].size())

When I have tried to prune the linear layers in the model, the saved model does not show the same increase in size as when I prune the LSTM layers. Any idea what could be causing this?

iampotato

1 Answer


This is because pruning reparametrizes each pruned tensor: the original values are kept as a new parameter weight_orig, and the binary mask is stored as a buffer weight_mask, which roughly doubles the storage for that tensor and increases the saved model size.

If you wish to remove those extra tensors, call torch.nn.utils.prune.remove() on each pruned module. This makes the pruning permanent: weight_orig and weight_mask are deleted and the masked values are written back into weight (note that the zeroed weights are no longer frozen after this step, so further training can make them non-zero again).
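As a minimal sketch of this reparametrization, here is a toy nn.Linear layer (the layer and its sizes are illustrative, not taken from the question's model) showing weight_orig / weight_mask appearing after pruning and disappearing after prune.remove():

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy layer; any prunable module behaves the same way.
layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# After pruning, 'weight' is recomputed as weight_orig * weight_mask:
# weight_orig is a parameter, weight_mask is a buffer.
print(sorted(name for name, _ in layer.named_parameters()))  # includes 'weight_orig'
print(sorted(name for name, _ in layer.named_buffers()))     # includes 'weight_mask'

# prune.remove() folds the mask into 'weight' permanently and
# deletes weight_orig and weight_mask, shrinking the state_dict back.
prune.remove(layer, "weight")
print(sorted(layer.state_dict().keys()))                     # back to 'bias', 'weight'
```

The same applies per weight tensor of an LSTM (e.g. weight_ih_l0, weight_hh_l0), which is why pruning an RNN without calling prune.remove() inflates the checkpoint noticeably.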

See the PyTorch Pruning Tutorial for details.