I created a GRU model in JAX using Flax, and I initialize the model parameters with `model.init` as follows:
```python
import jax.numpy as np
from jax import random
import flax.linen as nn
from jax.nn import initializers

class RNN(nn.Module):
    n_RNN_units: int

    @nn.compact
    def __call__(self, carry, inputs):
        carry, outputs = nn.GRUCell()(carry, inputs)
        return carry, outputs

    def init_state(self):
        return nn.GRUCell.initialize_carry((), (), self.n_RNN_units, init_fn = initializers.zeros)

# instantiate an RNN (GRU) model
n_RNN_units = 200
model = RNN(n_RNN_units = n_RNN_units)

# initialize the parameters of the model (weights and biases)
data_dim = 20
params = model.init(carry = np.empty((n_RNN_units,)), inputs = np.empty((data_dim,)), rngs = {'params': random.PRNGKey(1)})
```
Unfortunately, the FrozenDict `params` returned by `model.init` contains only the weights and biases of the GRU, not the initial hidden state (carry). Is there a way to tell `model.init` (1) that I also want to learn the initial hidden state and (2) which initializer function to use for it?
Alternatively, if there is a better way to do this that does not involve `model.init`, feel free to suggest it.
Thanks in advance