
I am training a PyTorch model in which I want to freeze (and later unfreeze) certain parameters. I was under the impression that simply setting param.requires_grad = False would accomplish this, but that does not seem to be the case for optimizers with momentum: the "frozen" parameters keep changing. I know that I can either instantiate a new optimizer or change the parameter list of the existing one, but neither approach lets me easily unfreeze the parameters later without keeping an extra reference to all the parameters the optimizer was originally updating.
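A plausible explanation, assuming a PyTorch version in which opt.zero_grad() defaults to set_to_none=False (the default before PyTorch 2.0): zero_grad only fills the frozen parameter's existing .grad with zeros, so SGD still treats it as a parameter with a gradient and keeps applying the stored momentum. The sketch below uses a made-up scalar parameter w to show the effect, and contrasts it with set_to_none=True, under which SGD skips any parameter whose .grad is None.

```python
import torch


def freeze_after_one_step(set_to_none: bool) -> float:
    """Take one SGD+momentum step, 'freeze' the parameter, then take another step."""
    w = torch.nn.Parameter(torch.tensor(1.0))  # hypothetical scalar parameter
    opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)

    opt.zero_grad(set_to_none=set_to_none)
    (2 * w).backward()  # dL/dw = 2
    opt.step()          # momentum buffer = 2.0, w = 1.0 - 0.1 * 2.0 = 0.8

    w.requires_grad = False
    opt.zero_grad(set_to_none=set_to_none)
    # set_to_none=False: .grad is a zero tensor, so buffer = 0.9 * 2.0 + 0 = 1.8
    #                    and w = 0.8 - 0.1 * 1.8 = 0.62
    # set_to_none=True:  .grad is None, so SGD skips w and it stays at 0.8
    opt.step()
    return w.item()


print(freeze_after_one_step(set_to_none=False))  # approximately 0.62
print(freeze_after_one_step(set_to_none=True))   # approximately 0.8
```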

I think the desired result could be achieved by setting the momentum_buffer in the optimizer's state to zero, but I am not sure how to do this, as it cannot be easily accessed.
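For torch.optim.SGD the buffers are reachable through opt.state, a mapping keyed by the parameter tensors themselves; after the first step with momentum != 0, each entry holds a "momentum_buffer" tensor. The sketch below zeroes those buffers for a chosen set of parameters. zero_momentum is a hypothetical helper, the tiny network and random input are made up, and set_to_none=False is passed explicitly to mimic the older zero_grad default. This only keeps the parameters fixed while their .grad stays a zero tensor and weight_decay is 0; with weight decay the gradient is no longer zero and they would still move.

```python
import torch
import torch.nn as nn


def zero_momentum(opt: torch.optim.Optimizer, params) -> None:
    """Hypothetical helper: zero the SGD momentum buffers of `params` in place."""
    for p in params:
        state = opt.state.get(p)  # .get avoids creating empty entries in the defaultdict
        if state is not None and state.get("momentum_buffer") is not None:
            state["momentum_buffer"].zero_()


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
x = torch.randn(8, 4)  # made-up input


def train_step():
    # set_to_none=False keeps zero-filled .grad tensors, like the pre-2.0 default
    opt.zero_grad(set_to_none=False)
    net(x).sum().backward()
    opt.step()


train_step()  # builds non-zero momentum buffers for every parameter

frozen = list(net[0].parameters())
for p in frozen:
    p.requires_grad = False
zero_momentum(opt, frozen)

before = [p.detach().clone() for p in frozen]
train_step()
print(all(torch.equal(a, b) for a, b in zip(before, frozen)))  # True when weight_decay == 0
```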

The code below reproduces the effect, with both known "solutions" commented out.

```python
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm


class Flatten(nn.Module):
    """Flatten (N, C, H, W) images to (N, C*H*W)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.view((x.size()[0], -1))


def main():
    data = torchvision.datasets.MNIST(
        "./data",
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    )
    data_loader = torch.utils.data.DataLoader(data, batch_size=1000, shuffle=True)

    net = nn.Sequential(Flatten(),
                        nn.Linear(28 * 28, 100),
                        nn.ReLU(),
                        nn.Linear(100, 10))
    opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

    for e in range(2):
        # Snapshot the parameters to detect which ones change during this epoch.
        old_params = [p.detach().clone() for p in net.parameters()]
        if e == 1:
            # "Freeze" the first Linear layer (its weight and bias).
            for j, p in enumerate(net.parameters()):
                if j < 2:
                    p.requires_grad = False
            # Known solution 1: drop the frozen parameters from the optimizer.
            # opt.param_groups[0]['params'] = opt.param_groups[0]['params'][2:]

        # Known solution 2: recreate the optimizer (this also discards its momentum state).
        # opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

        for x, label in tqdm(data_loader):
            loss = torch.nn.functional.cross_entropy(net(x), label)
            opt.zero_grad()
            loss.backward()
            opt.step()
        print(loss)

        new_params = [p.detach().clone() for p in net.parameters()]
        change = [(~(p1 == p2).all()).item() for p1, p2 in zip(old_params, new_params)]
        print("Epoch: %d \t params changed: %s" % (e, change))
        print([p.requires_grad for p in net.parameters()])


if __name__ == '__main__':
    main()
```
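One way to freeze and unfreeze without rebuilding the optimizer or keeping separate parameter lists, sketched under the assumption that the optimizer is torch.optim.SGD, whose step() skips any parameter whose .grad is None: toggle requires_grad, clear .grad when freezing, and call zero_grad(set_to_none=True) (the default from PyTorch 2.0 on). set_frozen and train_steps are hypothetical helpers, nn.Flatten stands in for the custom Flatten class, and random tensors replace MNIST so the sketch runs offline.

```python
import torch
import torch.nn as nn


def set_frozen(params, frozen: bool) -> None:
    """Hypothetical helper: freeze or unfreeze `params` in place.

    The optimizer keeps its parameter references and its state (momentum
    buffers); only requires_grad and .grad are touched.
    """
    for p in params:
        p.requires_grad_(not frozen)
        if frozen:
            # SGD.step() skips parameters whose .grad is None, so the leftover
            # momentum buffer cannot move a frozen parameter.
            p.grad = None


def train_steps(net, opt, x, y, n=5):
    for _ in range(n):
        # set_to_none=True keeps the frozen parameters' .grad at None
        opt.zero_grad(set_to_none=True)
        loss = nn.functional.cross_entropy(net(x), y)
        loss.backward()
        opt.step()


torch.manual_seed(0)
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 100), nn.ReLU(), nn.Linear(100, 10))
opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
x, y = torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))  # random stand-in for MNIST

first_layer = list(net[1].parameters())
train_steps(net, opt, x, y)  # everything trains and momentum builds up

set_frozen(first_layer, True)
before = [p.detach().clone() for p in first_layer]
train_steps(net, opt, x, y)
frozen_unchanged = all(torch.equal(a, b) for a, b in zip(before, first_layer))

set_frozen(first_layer, False)
train_steps(net, opt, x, y)
unfrozen_changed = not all(torch.equal(a, b) for a, b in zip(before, first_layer))
print(frozen_unchanged, unfrozen_changed)
```

The momentum buffers stay in opt.state while a parameter is frozen, so the first updates after unfreezing still include the old momentum. If a clean restart is preferred, zero opt.state[p]["momentum_buffer"] for those parameters at unfreeze time.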
Dillmann