To start, let's load the model:
import torchvision.models as models
import torch.nn as nn
from torch.nn.utils import prune

# load an EfficientNet-B0 pretrained on ImageNet
model = models.efficientnet_b0(pretrained=True)

num_classes = 10

# unfreeze every layer for fine-tuning
for params in model.parameters():
    params.requires_grad = True

# replace the classification head with a 10-class linear layer
model.classifier[1] = nn.Linear(in_features=1280, out_features=num_classes)
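As a quick sanity check (a minimal sketch, assuming the same 28x28 RGB input we'll feed to the summary below), a dummy forward pass should now return ten logits:

import torch

model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 28, 28))  # dummy batch of one 28x28 RGB image
print(out.shape)  # torch.Size([1, 10])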
A clear blueprint of the model:
Printing the model directly isn't a very clear way to get an overview of its components. I personally prefer using torchinfo, e.g.:
from torchinfo import summary
summary(model, input_size=(1, 3, 28, 28))
This will give you the following output:
=========================================================================================================
Layer (type:depth-idx) Output Shape Param #
=========================================================================================================
EfficientNet [1, 10] --
├─Sequential: 1-1 [1, 1280, 1, 1] --
│ └─Conv2dNormActivation: 2-1 [1, 32, 14, 14] --
│ │ └─Conv2d: 3-1 [1, 32, 14, 14] 864
│ │ └─BatchNorm2d: 3-2 [1, 32, 14, 14] 64
│ │ └─SiLU: 3-3 [1, 32, 14, 14] --
│ └─Sequential: 2-2 [1, 16, 14, 14] --
│ │ └─MBConv: 3-4 [1, 16, 14, 14] 1,448
│ └─Sequential: 2-3 [1, 24, 7, 7] --
│ │ └─MBConv: 3-5 [1, 24, 7, 7] 6,004
│ │ └─MBConv: 3-6 [1, 24, 7, 7] 10,710
│ └─Sequential: 2-4 [1, 40, 4, 4] --
│ │ └─MBConv: 3-7 [1, 40, 4, 4] 15,350
│ │ └─MBConv: 3-8 [1, 40, 4, 4] 31,290
│ └─Sequential: 2-5 [1, 80, 2, 2] --
│ │ └─MBConv: 3-9 [1, 80, 2, 2] 37,130
│ │ └─MBConv: 3-10 [1, 80, 2, 2] 102,900
│ │ └─MBConv: 3-11 [1, 80, 2, 2] 102,900
│ └─Sequential: 2-6 [1, 112, 2, 2] --
│ │ └─MBConv: 3-12 [1, 112, 2, 2] 126,004
│ │ └─MBConv: 3-13 [1, 112, 2, 2] 208,572
│ │ └─MBConv: 3-14 [1, 112, 2, 2] 208,572
│ └─Sequential: 2-7 [1, 192, 1, 1] --
│ │ └─MBConv: 3-15 [1, 192, 1, 1] 262,492
│ │ └─MBConv: 3-16 [1, 192, 1, 1] 587,952
│ │ └─MBConv: 3-17 [1, 192, 1, 1] 587,952
│ │ └─MBConv: 3-18 [1, 192, 1, 1] 587,952
│ └─Sequential: 2-8 [1, 320, 1, 1] --
│ │ └─MBConv: 3-19 [1, 320, 1, 1] 717,232
│ └─Conv2dNormActivation: 2-9 [1, 1280, 1, 1] --
│ │ └─Conv2d: 3-20 [1, 1280, 1, 1] 409,600
│ │ └─BatchNorm2d: 3-21 [1, 1280, 1, 1] 2,560
│ │ └─SiLU: 3-22 [1, 1280, 1, 1] --
├─AdaptiveAvgPool2d: 1-2 [1, 1280, 1, 1] --
├─Sequential: 1-3 [1, 10] --
│ └─Dropout: 2-10 [1, 1280] --
│ └─Linear: 2-11 [1, 10] 12,810
=========================================================================================================
Total params: 4,020,358
Trainable params: 4,020,358
Non-trainable params: 0
Total mult-adds (M): 8.11
=========================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 1.97
Params size (MB): 16.08
Estimated Total Size (MB): 18.06
=========================================================================================================
We can see that our model is composed of two main parameterized modules (with a parameter-free AdaptiveAvgPool2d in between):
- Sequential: 1-1 (the feature extractor, itself composed of multiple submodules)
- Sequential: 1-3 (the classifier)
To get the names of the top-level modules, you can use:
{k.split(".")[0] for k in model.state_dict().keys()}
Which will output:
{'classifier', 'features'}
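Note that state_dict only lists modules that own parameters, which is why the pooling layer doesn't show up. To see every top-level module, named_children gives the full picture (a minimal sketch):

for name, module in model.named_children():
    print(name, "->", type(module).__name__)
# features -> Sequential
# avgpool -> AdaptiveAvgPool2d
# classifier -> Sequential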
Pruning:
For pruning, we're mostly interested in the weights and biases of the individual layers, not the abstracted submodules:
for name, param in model.named_parameters():
    print(name)
This will output:
features.0.0.weight
features.0.1.weight
features.0.1.bias
features.1.0.block.0.0.weight
features.1.0.block.0.1.weight
features.1.0.block.0.1.bias
...
features.8.1.weight
features.8.1.bias
classifier.1.weight
classifier.1.bias
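Each dotted name is a submodule path plus a parameter name: "features.0.0.weight" belongs to the module model.features[0][0]. If you'd rather not translate the dots into indexing by hand, get_submodule resolves the path for you (a small sketch; the printed repr is abbreviated):

# "features.0.0.weight" lives on the module at path "features.0.0"
conv = model.get_submodule("features.0.0")
print(conv)  # Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), ...)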
Now you can prune any layer you want, as described in the documentation:
# randomly prune 30% of the weights of the first conv layer
module_to_prune = model.features[0][0]
prune.random_unstructured(module_to_prune, name="weight", amount=0.3)
Let's check that the specified weights were pruned, and that a new weight_orig entry appears in named_parameters:
for name, param in model.named_parameters():
    print(name)
Output:
features.0.0.weight_orig
features.0.1.weight
features.0.1.bias
features.1.0.block.0.0.weight
features.1.0.block.0.1.weight
...
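Pruning reparametrizes the layer: the original values are kept in weight_orig, a binary weight_mask buffer is registered on the module, and the effective weight is recomputed from the two on every forward pass. As a quick verification (a minimal sketch), you can measure the resulting sparsity directly and, once you're satisfied, make the pruning permanent with prune.remove:

import torch

# the binary mask is stored as a buffer, not a parameter
print([name for name, _ in module_to_prune.named_buffers()])  # ['weight_mask']

# roughly 30% of the weights should now be zero
sparsity = float(torch.sum(module_to_prune.weight == 0)) / module_to_prune.weight.nelement()
print(f"sparsity: {sparsity:.1%}")

# fold the mask into the weight and remove weight_orig / weight_mask
prune.remove(module_to_prune, name="weight")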