
I am learning from the Deep Compression paper [Han et al.] and applying it to ResNet-18.

I also wrote the following code, which multiplies each weight tensor by a mask so that the k% of weights with the lowest magnitude are set to zero. However, this approach doesn't work well for me. Is there an efficient solution?

import numpy as np
import torch

prune = 0.1  # fraction of weights to set to zero

def prune_weights(torchweights):
    # Work on a CPU NumPy copy; keep the signed values for the returned weights.
    weights = torchweights.cpu().numpy()
    weightshape = weights.shape

    # argsort returns flat indices ordered from smallest to largest magnitude
    # (indices, not ranks, so they must be used to index the mask).
    order = np.abs(weights).reshape(-1).argsort()

    num = weights.size
    prune_num = int(np.round(num * prune))

    # Mask is 1 everywhere except at the prune_num smallest-magnitude positions.
    masks = np.ones(num, dtype=weights.dtype)
    masks[order[:prune_num]] = 0
    print("total weights:", num)
    print("weights pruned:", prune_num)

    masks = masks.reshape(weightshape)
    weights = masks * weights

    return torch.from_numpy(weights).to(torchweights.device), masks

# Prune weights.
# The names of the pruned layers are saved in addressbook and their masks in maskbook.
# These will be used during training to keep the pruned weights at zero.
# `checkpoint` is assumed to be a dict loaded earlier (e.g. with torch.load) that has a 'net' entry.
addressbook = []
maskbook = []
for k, v in net.state_dict().items():
    if "conv2" in k:
        addressbook.append(k)
        print("pruning layer:", k)
        weights, masks = prune_weights(v)
        maskbook.append(masks)
        checkpoint['net'][k] = weights

checkpoint['address'] = addressbook
checkpoint['mask'] = maskbook
net.load_state_dict(checkpoint['net'])
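The comments in the question say the masks are used during training to keep the pruned weights at zero. One common way to do that, sketched here under the assumption of a plain SGD training loop, is to multiply each weight by its mask again after every optimizer step, because the update (and momentum) can move pruned entries away from zero. The toy `nn.Sequential` model, the `apply_masks` helper, and the random data are hypothetical stand-ins, not part of the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model standing in for ResNet-18.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)
amount = 0.1  # fraction of each conv weight tensor to zero out

# One 0/1 mask per conv weight (4-D parameter), zeroing the smallest |w| entries.
masks = {}
for name, param in model.named_parameters():
    if param.dim() == 4:
        k = int(round(amount * param.numel()))
        idx = torch.topk(param.detach().abs().flatten(), k, largest=False).indices
        mask = torch.ones(param.numel())
        mask[idx] = 0
        masks[name] = mask.view(param.shape)


def apply_masks(model, masks):
    """Multiply masked parameters in place so pruned entries are exactly zero."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])


apply_masks(model, masks)  # prune once before training
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
x, target = torch.randn(4, 3, 16, 16), torch.randn(4, 8, 16, 16)
for _ in range(3):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()           # the update can move pruned weights away from zero
    apply_masks(model, masks)  # so re-apply the masks after every step
```

Re-applying the mask after the step is simple and does not depend on which optimizer is used; the cost is one element-wise multiply per pruned tensor per step.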

1 Answer


You can use torch.nn.utils.prune.

It seems you want to remove 10% of the weights in every Conv2d layer. If that is the case, you can do it this way:

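import torch
import torch.nn.utils.prune as prune
import torchvision

# load your model (here an untrained ResNet-18 from torchvision, as in the question;
# replace it with your own trained model)
net = torchvision.models.resnet18()

# in your example, you want to remove 10%
prune_perc = 0.1

# zero the 10% smallest-magnitude weights in every Conv2d layer
for name, module in net.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=prune_perc)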
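To see what `prune.l1_unstructured` does to one layer, here is a small sketch on a single convolution shaped like the 3x3 convs in ResNet-18's `layer1` (the standalone `conv` is only for illustration). While pruning is attached, the original tensor is kept as `weight_orig`, the mask as the buffer `weight_mask`, and `weight` is recomputed as their product, so the pruned entries stay zero during training, much like the question's `maskbook`. `prune.remove` then makes the pruning permanent.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# A single 3x3 conv shaped like those in ResNet-18's layer1: 64*64*3*3 = 36864 weights.
conv = torch.nn.Conv2d(64, 64, kernel_size=3, bias=False)
prune.l1_unstructured(conv, name="weight", amount=0.1)

# The original tensor moves to the parameter `weight_orig`, the 0/1 mask is stored in
# the buffer `weight_mask`, and `weight` is recomputed as weight_orig * weight_mask.
params_after_prune = [n for n, _ in conv.named_parameters()]
buffers_after_prune = [n for n, _ in conv.named_buffers()]
zeros_after_prune = int((conv.weight == 0).sum())
print(params_after_prune, buffers_after_prune, zeros_after_prune)

# Make the pruning permanent: `weight` becomes a plain parameter again, zeros included.
prune.remove(conv, "weight")
print([n for n, _ in conv.named_parameters()], int((conv.weight == 0).sum()))
```

Calling `prune.remove` is optional; leave the reparametrization in place if you fine-tune after pruning and want the zeros enforced automatically.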

Berriel
  • That's not really the same thing. I want to prune the bottom k% of the weights, and l1_unstructured is different. – Chris Cheng Aug 27 '21 at 03:21
  • @ChrisCheng Can you please elaborate? How is it different? The [docs](https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.l1_unstructured.html) read: "_Prunes [...] by removing the specified amount of (currently unpruned) units with the lowest L1-norm._" – Berriel Aug 27 '21 at 12:27
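For unstructured pruning, each "unit" is a single weight, so its L1 norm is just its absolute value. A quick way to check the discussion above is to build the question's "bottom k%" mask by hand and compare it with the mask `l1_unstructured` produces. The sketch below assumes a randomly initialized conv (so there are no ties between magnitudes) and uses `kthvalue` to find the magnitude threshold.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = torch.nn.Conv2d(64, 64, kernel_size=3, bias=False)
w = conv.weight.detach().clone()  # copy of the weights before pruning
amount = 0.1

# Manual "bottom k%" rule from the question: zero the round(amount * N) smallest |w|.
k = int(round(amount * w.numel()))
threshold = w.abs().flatten().kthvalue(k).values  # k-th smallest magnitude
manual_mask = (w.abs() > threshold).float()       # keep only entries above it

prune.l1_unstructured(conv, name="weight", amount=amount)
same = torch.equal(manual_mask, conv.weight_mask)
print("masks identical:", same)
```

One difference worth noting, as the quoted docs say, is that `amount` applies to the currently unpruned weights: calling `l1_unstructured` a second time with `amount=0.1` prunes 10% of the remaining weights rather than another 10% of the original tensor, which matters for iterative pruning schedules.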