Here is a minimal working example (MWE) of the iterative pruning code I'm trying to run:
import torch
from torch import nn
import numpy as np
import torch.nn.utils.prune as prune

class myNN(nn.Module):
    def __init__(self, dims):
        super(myNN, self).__init__()
        self.fc1 = nn.Linear(dims[0], dims[1])
        self.fc2 = nn.Linear(dims[1], dims[2])
        self.fc3 = nn.Linear(dims[2], dims[3])
        self.fc4 = nn.Linear(dims[3], dims[4])

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
model = myNN([784, 128, 64, 32, 10])  # example layer sizes, just for the MWE

numIt = 10
for num_pruning in range(numIt):
    # Prune network globally by L1 magnitude
    parameters_to_prune = (
        (model.fc1, 'weight'), (model.fc1, 'bias'),
        (model.fc2, 'weight'), (model.fc2, 'bias'),
        (model.fc3, 'weight'), (model.fc3, 'bias'),
        (model.fc4, 'weight'), (model.fc4, 'bias'),
    )
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.1,
    )
    # Train network...
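As far as I can tell, after a pass of global_unstructured each pruned tensor in a layer is replaced by a <name>_orig parameter plus a <name>_mask buffer (with a forward pre-hook recomputing the pruned tensor), which is the extra state I suspect is piling up. This is the small snippet I use to look at that (it just prints what is attached to each Linear layer):

# Inspect what pruning adds to each Linear layer: for every pruned tensor,
# the original values live in '<name>_orig' and the binary mask in '<name>_mask'.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        params = [n for n, _ in module.named_parameters(recurse=False)]
        buffers = [n for n, _ in module.named_buffers(recurse=False)]
        print(name, params, buffers)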
However, as numIt grows very large, I run out of memory. I have tried to tackle this issue by means of prune.remove(). To apply it to every layer in the network, I define the following function:
def myPruneRemove(model):
    # Make the current pruning permanent: reassign the pruned tensors to
    # 'weight'/'bias' and drop the *_orig / *_mask reparametrization.
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.remove(module, 'weight')
            prune.remove(module, 'bias')
Then, I call myPruneRemove(model) inside the loop, right after the global pruning step, as in the sketch below.
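Concretely, the revised loop I'm running is roughly the following (same example model and dims as above; training is still just a placeholder, and I moved parameters_to_prune out of the loop since the (module, name) pairs don't change between iterations):

# Sketch of the revised loop: global prune, then immediately make it permanent.
parameters_to_prune = tuple(
    (layer, name)
    for layer in (model.fc1, model.fc2, model.fc3, model.fc4)
    for name in ('weight', 'bias')
)

numIt = 10
for num_pruning in range(numIt):
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.1,
    )
    myPruneRemove(model)  # bake the current masks into weight/bias, drop the hooks
    # Train network...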
My question is whether this is the right way to do iterative global pruning without running out of memory. Does prune.remove() keep track of which weights/biases have been pruned in each iteration, so that the pruning accumulates over iterations?
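For reference, this is the check I was planning to use to see whether sparsity actually accumulates: after prune.remove() the pruned entries should be plain zeros in module.weight, so the reported fraction should grow each iteration if pruning is cumulative (the helper name is just a placeholder of mine):

# Hypothetical helper: report the fraction of exactly-zero weights per Linear layer.
def report_sparsity(model):
    total, zeros = 0, 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            layer_zeros = (module.weight == 0).sum().item()
            layer_total = module.weight.numel()
            print(f"{name}: {layer_zeros / layer_total:.2%} of weights are zero")
            zeros += layer_zeros
            total += layer_total
    print(f"overall: {zeros / total:.2%} of weights are zero")

I would call report_sparsity(model) at the end of each pruning iteration.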