I have some big networks that I define with nn.Sequential. I need to initialize the weights of each layer / network depending on some criteria. For instance, I want one network to output very large values. I can do that easily by using .apply() when I build the network.
I also need to be able to reset the weights at any time, and the weights have to be reset with the same initialization functions I originally called with .apply().
This answer provides a clean way of resetting all weights, but it uses the default initializations. How can I do something similar that calls the same initializations I used in the first place?
MWE below. In this case I initialize a simple network where the first layer has very small weights and the second has very large weights. I need reset_all_weights to initialize the network in the same way, and the approach has to generalize to any initialization function I define.
import torch
import torch.nn as nn


def _init_normal(module, value):
    # custom initialization: draw weights and bias from N(0, value)
    module.bias.data.normal_(0., value)
    module.weight.data.normal_(0., value)


def reset_all_weights(self):
    ''' https://stackoverflow.com/a/69905671/754136 '''
    @torch.no_grad()
    def weight_reset(m: nn.Module):
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    self.apply(fn=weight_reset)


m = nn.Sequential(
    nn.Linear(5, 5).apply(lambda x: _init_normal(x, 0.0001)),
    nn.ReLU(),
    nn.Linear(5, 5).apply(lambda x: _init_normal(x, 10.)),
)

for p in m.parameters():
    print(p)
# first layer weights are small, second are large

reset_all_weights(m)

for p in m.parameters():
    print(p)
# however, now both layers have similar weights
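The best workaround I have come up with so far is to remember which init function each layer was built with and re-apply it by hand. A rough, untested sketch of that idea is below (register_init, custom_reset and _layer_inits are just placeholder names I made up), but it feels clunky and I would prefer something cleaner along the lines of reset_all_weights above.

# rough sketch of my current workaround: store (layer, init_fn) pairs at
# build time, then re-run the exact same init functions on reset
_layer_inits = []

def register_init(layer, init_fn):
    init_fn(layer)                          # apply the custom init once at build time
    _layer_inits.append((layer, init_fn))   # remember it for later resets
    return layer

def custom_reset():
    for layer, init_fn in _layer_inits:
        init_fn(layer)                      # re-apply the same init function

m = nn.Sequential(
    register_init(nn.Linear(5, 5), lambda x: _init_normal(x, 0.0001)),
    nn.ReLU(),
    register_init(nn.Linear(5, 5), lambda x: _init_normal(x, 10.)),
)

This does keep the original initializations, but I have to thread the bookkeeping through every place a layer is created, which does not scale well to my bigger networks.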