
In PyTorch, the following code can be used to initialise a layer:

import math
import torch
from torch import nn

def init_layer(in_features, out_features):
    x = nn.Linear(in_features, out_features)
    # Re-initialise the weight from U(-limit, limit) with limit = 1/sqrt(in_features)
    limit = 1.0 / math.sqrt(in_features)
    x.weight = nn.Parameter(
        data=torch.distributions.uniform.Uniform(-limit, limit).sample(x.weight.shape),
        requires_grad=True,
    )
    return x

How can I do the same thing using JAX & Haiku?

Thanks!

masha
  • I'm sorry, but you have to add what you have tried to do (and what errors you encountered) so far before asking this here. – tornikeo Feb 22 '22 at 20:24
  • Unfortunately, I have tried nothing so far. I just don't know from where to start. – masha Feb 23 '22 at 02:19
  • Try following [this](https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html) tutorial to the end first. If you still don't get it after that, I'll personally guide you through whatever you are doing. :) – tornikeo Feb 23 '22 at 09:07

1 Answer


You do this per layer by passing an initializer to it. Here is an example on a ResNet residual block:

import haiku as hk
import jax

class Residual(hk.Module):
    """The Residual block of ResNet."""

    def __init__(self, hidden_dim, use_1x1conv=False, strides=1,
                 init=hk.initializers.RandomNormal()):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.stride = strides
        self.init = init

        # Optional 1x1 convolution to match shapes on the skip connection.
        if use_1x1conv:
            self.proj = hk.Conv2D(hidden_dim, 1, stride=strides,
                                  w_init=self.init, b_init=self.init)
        else:
            self.proj = None

    def __call__(self, x, is_training=True):
        # First conv -> batch norm -> activation.
        y = hk.Conv2D(self.hidden_dim, 3, padding=(1, 1), stride=self.stride,
                      with_bias=False, w_init=self.init)(x)
        y = hk.BatchNorm(True, True, 0.9)(y, is_training)
        y = jax.nn.gelu(y)

        # Second conv -> batch norm.
        y = hk.Conv2D(self.hidden_dim, 3, padding=(1, 1),
                      with_bias=False, w_init=self.init)(y)
        y = hk.BatchNorm(True, True, 0.9)(y, is_training)

        # Project the input if needed, then add the skip connection.
        if self.proj:
            x = self.proj(x)

        return jax.nn.gelu(x + y)
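
For the exact equivalent of the PyTorch snippet in the question, you don't need a custom initializer at all. Below is a minimal sketch, assuming Haiku's built-in hk.initializers.RandomUniform and hk.Linear; the init_layer/forward names and the sizes are only for illustration:

import math
import haiku as hk
import jax
import jax.numpy as jnp

def init_layer(in_features, out_features):
    # Same limit as the PyTorch code: U(-1/sqrt(in_features), 1/sqrt(in_features)).
    limit = 1.0 / math.sqrt(in_features)
    uniform = hk.initializers.RandomUniform(minval=-limit, maxval=limit)
    return hk.Linear(out_features, w_init=uniform)

def forward(x):
    # Haiku modules must be created inside a function that is passed to hk.transform.
    return init_layer(x.shape[-1], 64)(x)

forward = hk.transform(forward)
params = forward.init(jax.random.PRNGKey(42), jnp.ones([8, 32]))

Like the PyTorch snippet, this only overrides the weight initialisation; the bias keeps Haiku's default (zeros). If you also want a uniform bias, pass the same initializer as b_init.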