Let's say I wanted to multiply all parameters of a neural network in PyTorch (an instance of a class inheriting from torch.nn.Module) by 0.9. How would I do that?


2 Answers

Let `net` be an instance of a neural network `nn.Module`. Then, to multiply all parameters by 0.9:

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    param.copy_(transformed_param)
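
To make this concrete, here is a self-contained sketch using a toy `nn.Linear` layer as a stand-in for `net` (the layer and names are illustrative, not from the original answer), which verifies that every parameter ends up scaled by 0.9:

```python
import torch
import torch.nn as nn

# Toy stand-in for `net`; any nn.Module works the same way.
net = nn.Linear(4, 3)

# Snapshot the originals so we can verify the scaling afterwards.
original = {name: p.detach().clone() for name, p in net.named_parameters()}

# state_dict() returns detached tensors that share storage with the
# module's parameters, so copy_() updates the network in place.
state_dict = net.state_dict()
for name, param in state_dict.items():
    param.copy_(param * 0.9)

# Every parameter of the live module is now scaled.
for name, p in net.named_parameters():
    assert torch.allclose(p.detach(), original[name] * 0.9)
```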

If you want to update only the weights instead of every parameter:

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Don't update if this is not a weight.
    if "weight" not in name:
        continue
    
    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    param.copy_(transformed_param)
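
As a quick check of the weight-only variant, here is a sketch (again with a hypothetical toy `nn.Linear`) showing that the bias is left untouched:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 3)
original_weight = net.weight.detach().clone()
original_bias = net.bias.detach().clone()

state_dict = net.state_dict()
for name, param in state_dict.items():
    # Skip anything that is not a weight (e.g. biases).
    if "weight" not in name:
        continue
    param.copy_(param * 0.9)

# The weight is scaled; the bias is unchanged.
assert torch.allclose(net.weight.detach(), original_weight * 0.9)
assert torch.equal(net.bias.detach(), original_bias)
```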
Comments:

- Does this make sense? I would expect some vector operation to perform this, otherwise Python would kill the speed; furthermore, would this work on GPU? – Gulzar Feb 17 '19 at 15:15
- @Gulzar Performance impact is almost non-existent. `state_dict` is a small collection (e.g. 10) of references to large tensors (e.g. 16x512x64x64). The `* 0.9` operation is run by external C++ libraries which are usually "vectorized". Calling the `* 0.9` operation from Python is quite cheap (1 microsecond per call * 10 calls = 10 microseconds) since Python doesn't do any of the actual computation. – Mateen Ulhaq Jan 19 '22 at 07:21
- BTW, couldn't one do `param.copy_(transformed_param)` to avoid looking up `state_dict[name]` again? I don't think `state_dict` overrides `__setattr__` anyway... in fact, it's just an `OrderedDict` that is created only when `net.state_dict()` is called. – Mateen Ulhaq Jan 19 '22 at 07:39

A different way of achieving this is to use `nn.Module.parameters()`.

Initialize module:

>>> a = torch.nn.Linear(2, 2)
>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0.1770, -0.2151],
                      [-0.6543,  0.6637]])),
             ('bias', tensor([-0.0524,  0.6807]))])

Change the parameters:

for p in a.parameters():
    p.data *= 0

See the effect:

>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0., -0.],
                      [-0., 0.]])),
             ('bias', tensor([-0., 0.]))])
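
The `p.data` idiom still works, but the same in-place update can also be sketched with `torch.no_grad()`, the form the PyTorch documentation now favors for modifying parameters outside of autograd (applied here to the question's 0.9 factor):

```python
import torch

a = torch.nn.Linear(2, 2)
original_weight = a.weight.detach().clone()
original_bias = a.bias.detach().clone()

# no_grad() stops autograd from recording the in-place update, which
# would otherwise raise an error on a leaf tensor that requires grad.
with torch.no_grad():
    for p in a.parameters():
        p *= 0.9

assert torch.allclose(a.weight.detach(), original_weight * 0.9)
assert torch.allclose(a.bias.detach(), original_bias * 0.9)
```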