Let's say I wanted to multiply all parameters of a neural network in PyTorch (an instance of a class inheriting from `torch.nn.Module`) by 0.9. How would I do that?

2 Answers
Let `net` be an instance of a neural network `nn.Module`. Then, to multiply all parameters by 0.9:
state_dict = net.state_dict()
for name, param in state_dict.items():
    # Transform the parameter as required.
    transformed_param = param * 0.9
    # Update the parameter.
    param.copy_(transformed_param)
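
As a quick sanity check, here is a minimal sketch using a hypothetical `torch.nn.Linear(2, 2)` as a stand-in for `net`:

import torch

net = torch.nn.Linear(2, 2)  # stand-in for any nn.Module
before = net.weight.detach().clone()

state_dict = net.state_dict()
for name, param in state_dict.items():
    param.copy_(param * 0.9)

# The state_dict tensors share storage with net's parameters,
# so net itself was updated in place.
print(torch.allclose(net.weight, before * 0.9))  # True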
If you want to update only the weights rather than every parameter:
state_dict = net.state_dict()
for name, param in state_dict.items():
    # Don't update if this is not a weight.
    if "weight" not in name:
        continue
    # Transform the parameter as required.
    transformed_param = param * 0.9
    # Update the parameter.
    param.copy_(transformed_param)
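
The name check works because `state_dict()` keys follow the module hierarchy and conventionally end in `weight` or `bias`. A small sketch (the two-layer `Sequential` is just an illustration):

import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
# Keys are "<submodule>.<parameter>" strings.
print(list(net.state_dict().keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias']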

- Does this make sense? I would expect some vector operation to perform this, otherwise Python would kill the speed; also, would this work on a GPU? – Gulzar Feb 17 '19 at 15:15
- @Gulzar Performance impact is almost non-existent. `state_dict` is a small collection (e.g. 10 items) of references to large tensors (e.g. 16x512x64x64). The `* 0.9` operation is run by external C++ libraries, which are usually "vectorized". Calling the `* 0.9` operation from Python is quite cheap (1 microsecond per call * 10 calls = 10 microseconds), since Python doesn't do any of the actual computation. – Mateen Ulhaq Jan 19 '22 at 07:21
- BTW, couldn't one do `param.copy_(transformed_param)` to avoid looking up `state_dict[name]` again? I don't think `state_dict` overrides `__setattr__` anyway... in fact, it's just an `OrderedDict` that is created only when `net.state_dict()` is called. – Mateen Ulhaq Jan 19 '22 at 07:39
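
The reason `param.copy_(...)` updates `net` at all is that, by default, `state_dict()` returns detached tensors that share storage with the module's parameters. A quick check, again with a hypothetical `torch.nn.Linear` as the model:

import torch

net = torch.nn.Linear(2, 2)
sd = net.state_dict()
# Same underlying storage, so an in-place copy_ on the state_dict
# tensor mutates net.weight as well.
print(sd["weight"].data_ptr() == net.weight.data_ptr())  # True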
A different way of achieving this is using `nn.Module.parameters()`.
Initialize module:
>>> a = torch.nn.Linear(2, 2)
>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0.1770, -0.2151],
                      [-0.6543,  0.6637]])),
             ('bias', tensor([-0.0524,  0.6807]))])
Change the parameters (multiplying by 0 here so the effect is obvious; use 0.9 for the question's case):
for p in a.parameters():
    p.data *= 0
See the effect:
>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0., -0.],
                      [-0.,  0.]])),
             ('bias', tensor([-0.,  0.]))])
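
Note that writing through `.data` sidesteps autograd's safety checks. An alternative sketch (not from the original answer) that does the same in-place scaling under `torch.no_grad()`:

import torch

a = torch.nn.Linear(2, 2)
with torch.no_grad():
    for p in a.parameters():
        p.mul_(0.9)  # in-place scaling; no autograd history recorded

Both variants modify the parameters in place and work the same for parameters on a GPU.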

– Onno Eberhard