
I am training a model with EfficientNet in PyTorch and, to reduce overfitting, I want to prune some of its parameters.

My model is implemented as follows:


import torchvision.models as models
import torch.nn as nn

model = models.efficientnet_b0(pretrained=True)

for params in model.parameters():
    params.requires_grad = True

model.classifier[1] = nn.Linear(in_features=1280, out_features=num_classes)

This tutorial for pruning a PyTorch model suggests isolating a module and then pruning it as follows:

module = model.conv1
prune.random_unstructured(module, name="weight", amount=0.3)

However, there is no module named conv1 in the EfficientNet model. I then listed its modules and got the following:

  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)

...

However, I do not know which of these are the module names, nor do I know how to isolate the module. How can I do so?

Am I on the right path for pruning my original model? Is there a different approach I should take? Thank you for the clarification and help.

1 Answer

To start with, let's load the model:

import torchvision.models as models
import torch.nn as nn
from torch.nn.utils import prune 

model = models.efficientnet_b0(pretrained=True)
num_classes = 10
for params in model.parameters():
    params.requires_grad = True

model.classifier[1] = nn.Linear(in_features=1280, out_features=num_classes)
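
Side note: on recent torchvision versions (0.13 and later), pretrained=True is deprecated in favor of the weights argument; a sketch of the equivalent call, assuming one of those versions:

from torchvision.models import EfficientNet_B0_Weights

# Equivalent to pretrained=True on torchvision >= 0.13.
model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)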

A clear blueprint of the model:

Printing the model directly isn't a very clear way to get an overview of its different components. I personally prefer using torchinfo, e.g.:

from torchinfo import summary
summary(model, input_size=(1, 3, 28, 28))

This will give you the following output:

=========================================================================================================
Layer (type:depth-idx)                                  Output Shape              Param #
=========================================================================================================
EfficientNet                                            [1, 10]                   --
├─Sequential: 1-1                                       [1, 1280, 1, 1]           --
│    └─Conv2dNormActivation: 2-1                        [1, 32, 14, 14]           --
│    │    └─Conv2d: 3-1                                 [1, 32, 14, 14]           864
│    │    └─BatchNorm2d: 3-2                            [1, 32, 14, 14]           64
│    │    └─SiLU: 3-3                                   [1, 32, 14, 14]           --
│    └─Sequential: 2-2                                  [1, 16, 14, 14]           --
│    │    └─MBConv: 3-4                                 [1, 16, 14, 14]           1,448
│    └─Sequential: 2-3                                  [1, 24, 7, 7]             --
│    │    └─MBConv: 3-5                                 [1, 24, 7, 7]             6,004
│    │    └─MBConv: 3-6                                 [1, 24, 7, 7]             10,710
│    └─Sequential: 2-4                                  [1, 40, 4, 4]             --
│    │    └─MBConv: 3-7                                 [1, 40, 4, 4]             15,350
│    │    └─MBConv: 3-8                                 [1, 40, 4, 4]             31,290
│    └─Sequential: 2-5                                  [1, 80, 2, 2]             --
│    │    └─MBConv: 3-9                                 [1, 80, 2, 2]             37,130
│    │    └─MBConv: 3-10                                [1, 80, 2, 2]             102,900
│    │    └─MBConv: 3-11                                [1, 80, 2, 2]             102,900
│    └─Sequential: 2-6                                  [1, 112, 2, 2]            --
│    │    └─MBConv: 3-12                                [1, 112, 2, 2]            126,004
│    │    └─MBConv: 3-13                                [1, 112, 2, 2]            208,572
│    │    └─MBConv: 3-14                                [1, 112, 2, 2]            208,572
│    └─Sequential: 2-7                                  [1, 192, 1, 1]            --
│    │    └─MBConv: 3-15                                [1, 192, 1, 1]            262,492
│    │    └─MBConv: 3-16                                [1, 192, 1, 1]            587,952
│    │    └─MBConv: 3-17                                [1, 192, 1, 1]            587,952
│    │    └─MBConv: 3-18                                [1, 192, 1, 1]            587,952
│    └─Sequential: 2-8                                  [1, 320, 1, 1]            --
│    │    └─MBConv: 3-19                                [1, 320, 1, 1]            717,232
│    └─Conv2dNormActivation: 2-9                        [1, 1280, 1, 1]           --
│    │    └─Conv2d: 3-20                                [1, 1280, 1, 1]           409,600
│    │    └─BatchNorm2d: 3-21                           [1, 1280, 1, 1]           2,560
│    │    └─SiLU: 3-22                                  [1, 1280, 1, 1]           --
├─AdaptiveAvgPool2d: 1-2                                [1, 1280, 1, 1]           --
├─Sequential: 1-3                                       [1, 10]                   --
│    └─Dropout: 2-10                                    [1, 1280]                 --
│    └─Linear: 2-11                                     [1, 10]                   12,810
=========================================================================================================
Total params: 4,020,358
Trainable params: 4,020,358
Non-trainable params: 0
Total mult-adds (M): 8.11
=========================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 1.97
Params size (MB): 16.08
Estimated Total Size (MB): 18.06
=========================================================================================================

We can see that the model is composed of two main modules:

  • Sequential: 1-1 (the feature extractor, which in turn is composed of multiple submodules)
  • Sequential: 1-3 (your classifier)
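
Since these are plain attributes of the model, you can index into them directly to isolate a submodule (a quick sketch; features and classifier are the attribute names shown in the module listing in the question):

print(model.features[0])  # the first Conv2dNormActivation block
print(model.classifier)   # the Dropout + Linear head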

To get the names of the top-level modules, you can use:

set([i.split(".")[0] for i in model.state_dict().keys()])

which will output:

{'classifier', 'features'}
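
Alternatively, model.named_children() iterates over the immediate submodules and also lists modules that have no parameters (which state_dict misses), such as the pooling layer; a small sketch using the standard nn.Module API:

for name, child in model.named_children():
    print(name, type(child).__name__)

# features   Sequential
# avgpool    AdaptiveAvgPool2d
# classifier Sequential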

Pruning:

For pruning, we're mostly interested in the weights and biases of the individual layers, not the abstract submodules:

for name, param in model.named_parameters():
    print(name)

This will output:

features.0.0.weight
features.0.1.weight
features.0.1.bias
features.1.0.block.0.0.weight
features.1.0.block.0.1.weight
features.1.0.block.0.1.bias
...

features.8.1.weight
features.8.1.bias
classifier.1.weight
classifier.1.bias
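
Each dotted name is the path to the module that owns the parameter, so you can resolve a name back to its module programmatically. A minimal sketch using get_submodule (available in recent PyTorch versions), with the path taken from the listing above:

# "features.1.0.block.0.0.weight" belongs to the module at the path
# "features.1.0.block.0.0"; get_submodule walks that dotted path.
conv = model.get_submodule("features.1.0.block.0.0")
print(conv)  # Conv2d(32, 32, kernel_size=(3, 3), ..., groups=32, bias=False)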

Now you can prune each layer you want, as described in the documentation:

module_to_prune = model.features[0][0]
prune.random_unstructured(module_to_prune, name="weight", amount=0.3)

Let's check whether the specified weights got pruned and whether a new weight_orig entry shows up in the named parameters:

for name, param in model.named_parameters():
    print(name)

Output:

features.0.0.weight_orig
features.0.1.weight
features.0.1.bias
features.1.0.block.0.0.weight
features.1.0.block.0.1.weight
...
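
Pruning is applied through a weight_mask buffer and a forward pre-hook that recomputes weight on the fly, so the tensor shapes don't change. Below is a sketch that verifies the sparsity, makes the pruning permanent with prune.remove, and shows one way to prune several layers at once; the loop over every Conv2d is only one possible choice of which layers to prune:

import torch

# The mask lives in a buffer next to the saved original weights.
print(dict(module_to_prune.named_buffers()).keys())  # dict_keys(['weight_mask'])

# Roughly 30% of the entries should now be zero (amount=0.3 above).
sparsity = float(torch.sum(module_to_prune.weight == 0)) / module_to_prune.weight.nelement()
print(f"sparsity: {sparsity:.2%}")

# prune.remove makes the pruning permanent: it folds the mask into
# weight and drops the weight_orig / weight_mask re-parametrization.
prune.remove(module_to_prune, "weight")

# One possible extension: prune every Conv2d in the network in a loop.
for layer_name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        prune.random_unstructured(layer, name="weight", amount=0.3)
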
  • Thank you very much! The pruning didn't do much to help my test accuracy, but I assume I'm not doing it properly. However, this code was very helpful! – Thai Pro May 11 '23 at 17:04
  • @ThaiPro I don't think you are using the correct approach - pruning is generally used to improve inference throughput, not accuracy. – simeonovich May 13 '23 at 04:19
  • This sort of pruning is unlikely to improve accuracy; it's better to use dropout, batch normalization, or other forms of regularization during training instead. It also won't change the inference time, since you're not dynamically changing the computation graph. It might help if your GPU supports sparse matrix multiplication, but even then the matrix has to satisfy certain conditions; see [here](https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/) – Laassairi Abdellah May 13 '23 at 11:10