Here I learn from the paper called Deep compression [Han et. al.] using resnet18
I also work the following code, the weight times the mask so that it is the after_weight pruned by the k% lowest weight to zero. But that code doesn't work for me. Any efficient solution?
prune = float(0.1)
def prune_weights(torchweights):
weights=np.abs(torchweights.cpu().numpy());
weightshape=weights.shape
rankedweights=weights.reshape(weights.size).argsort()#.reshape(weightshape)
num = weights.size
prune_num = int(np.round(num*prune))
count=0
masks = np.zeros_like(rankedweights)
for n, rankedweight in enumerate(rankedweights):
if rankedweight > prune_num:
masks[n]=1
else: count+=1
print("total weights:", num)
print("weights pruned:",count)
masks=masks.reshape(weightshape)
weights=masks*weights
return torch.from_numpy(weights).cuda(), masks
# prune weights
# The pruned weight location is saved in the addressbook and maskbook.
# These will be used during training to keep the weights zero.
addressbook=[]
maskbook=[]
for k, v in net.state_dict().items():
if "conv2" in k:
addressbook.append(k)
print("pruning layer:",k)
weights=v
weights, masks = prune_weights(weights)
maskbook.append(masks)
checkpoint['net'][k] = weights
checkpoint['address'] = addressbook
checkpoint['mask'] = maskbook
net.load_state_dict(checkpoint['net'])