You do this per layer; for example, on a ResNet residual block:
class Residual(hk.Module):
    """The Residual block of ResNet.

    The main path is conv 3x3 -> batch norm -> GELU -> conv 3x3 -> batch norm.
    Its output is added to the skip path (the input, optionally passed through
    a 1x1 convolution) and the sum goes through a final GELU.

    Args:
        hidden_dim: Output channel count handed to every convolution.
        use_1x1conv: When true, the skip path is projected by a 1x1
            convolution (using ``strides``) before the addition.
        strides: Stride of the first 3x3 convolution and of the optional 1x1
            projection. The second 3x3 convolution uses the default stride.
        init: Initializer used for the weights of every convolution and for
            the bias of the 1x1 projection.
    """

    def __init__(self, hidden_dim, use_1x1conv=False, strides=1,
                 init=hk.initializers.RandomNormal()):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.stride = strides
        self.init = init
        # The projection is the only convolution here that keeps its bias,
        # so it is the only one that needs ``b_init``.
        if use_1x1conv:
            self.proj = hk.Conv2D(
                hidden_dim,
                1,
                stride=strides,
                w_init=self.init,
                b_init=self.init,
            )
        else:
            self.proj = None

    def __call__(self, x, is_training=True):
        """Apply the block to ``x``.

        Args:
            x: Input array. Without the 1x1 projection it is added to the
                main-path output unchanged, so the two must have compatible
                shapes (NOTE(review): presumably this requires ``strides=1``
                and matching channels in that case; confirm with callers).
            is_training: Forwarded to both batch-norm layers.

        Returns:
            ``gelu(skip + main)``, where ``skip`` is ``x`` or its projection.
        """
        # NOTE(review): sub-modules are created inline in a fixed order;
        # reordering them would presumably change the auto-generated parameter
        # names, so keep the sequence stable.
        y = hk.Conv2D(
            self.hidden_dim,
            3,
            padding=(1, 1),
            stride=self.stride,
            with_bias=False,  # bias omitted; batch norm's offset follows
            w_init=self.init,
        )(x)
        y = hk.BatchNorm(
            create_scale=True, create_offset=True, decay_rate=0.9
        )(y, is_training)
        y = jax.nn.gelu(y)

        y = hk.Conv2D(
            self.hidden_dim,
            3,
            padding=(1, 1),
            with_bias=False,
            w_init=self.init,
        )(y)
        y = hk.BatchNorm(
            create_scale=True, create_offset=True, decay_rate=0.9
        )(y, is_training)

        # Compare against None explicitly rather than relying on the
        # truthiness of a module object.
        if self.proj is not None:
            x = self.proj(x)
        return jax.nn.gelu(x + y)