
I have some big networks that I define with nn.Sequential. I need to initialize the weights of each layer / network depending on some criteria. For instance, I want one network to output very large values.

I can do that easily by using .apply() when I build the network.

I also need to be able to reset the weights at any time, and I need the weights to be reset using the same initialization functions that I originally passed to .apply().

This answer provides a clean way of resetting all weights, but it uses the default initializations. How can I write something similar that calls the same initializations I used in the first place?

MWE below. In this case I initialize a simple network where the first layer has very small weights and the second has very large weights. I need reset_all_weights to initialize the network in the same way, and the approach has to generalize to any initialization function that I define.

import torch
import torch.nn as nn

def _init_normal(module, value):
    # fill the weights and bias with samples from N(0, value)
    module.bias.data.normal_(0., value)
    module.weight.data.normal_(0., value)


def reset_all_weights(model):
    ''' https://stackoverflow.com/a/69905671/754136 '''
    @torch.no_grad()
    def weight_reset(m: nn.Module):
        # call the module's built-in reset_parameters, if it has one
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()
    model.apply(fn=weight_reset)


m = nn.Sequential(
  nn.Linear(5, 5).apply(lambda x: _init_normal(x, 0.0001)),
  nn.ReLU(),
  nn.Linear(5, 5).apply(lambda x: _init_normal(x, 10.)),
)

for p in m.parameters():
    print(p)

# first layer weights are small, second are large

reset_all_weights(m)

for p in m.parameters():
    print(p)

# however, now both layers have similar weights
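
The closest I have gotten is to stash the init function on each module and re-run it on reset, roughly like the sketch below (the helper init_with, the function reset_all_weights_custom, and the attribute name _init_fn are all names I made up), but I am not sure this is the cleanest or most general pattern:

def init_with(module, fn):
    # apply the custom init now and remember it for later resets
    fn(module)
    module._init_fn = fn
    return module


def reset_all_weights_custom(model):
    @torch.no_grad()
    def weight_reset(m: nn.Module):
        fn = getattr(m, "_init_fn", None)
        if fn is not None:
            fn(m)  # re-run the stored custom init
        elif callable(getattr(m, "reset_parameters", None)):
            m.reset_parameters()  # fall back to the default reset
    model.apply(fn=weight_reset)


m = nn.Sequential(
  init_with(nn.Linear(5, 5), lambda x: _init_normal(x, 0.0001)),
  nn.ReLU(),
  init_with(nn.Linear(5, 5), lambda x: _init_normal(x, 10.)),
)
reset_all_weights_custom(m)  # first layer stays small, second stays large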