I am working on a GRU, and when I try to make predictions I get an error saying that I need to pass h to forward(). I have tried several things and ran out of patience after googling and scouring Stack Overflow for hours.

This is the class:

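import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # used by init_hidden()

class GRUNet(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, drop_prob=0.2):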
        super(GRUNet, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.gru = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True, dropout=drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x, h):
        out, h = self.gru(x, h)
        # use only the last time step of each sequence for the prediction
        out = self.fc(self.relu(out[:, -1]))
        return out, h
    
    def init_hidden(self, batch_size):
        weight = next(self.parameters())
        # zero initial hidden state of shape (n_layers, batch_size, hidden_dim)
        hidden = weight.new_zeros(self.n_layers, batch_size, self.hidden_dim).to(device)
        return hidden
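init_hidden() builds a tensor shaped (n_layers, batch_size, hidden_dim). Note that batch_first=True only changes the layout of the input and output sequences, not of the hidden state. A small sketch with a bare nn.GRU using the same sizes as GRUNet(24, 256, 1, 2); the batch size of 4 and sequence length of 10 are arbitrary illustration values:

```python
import torch
import torch.nn as nn

# Same GRU configuration as GRUNet(24, 256, 1, 2): input_dim=24, hidden_dim=256, n_layers=2
gru = nn.GRU(24, 256, 2, batch_first=True, dropout=0.2)
gru.eval()  # disable dropout between the stacked layers

x = torch.randn(4, 10, 24)   # (batch, seq_len, input_dim) because batch_first=True
h0 = torch.zeros(2, 4, 256)  # (n_layers, batch, hidden_dim): the hidden state is NOT batch-first

with torch.no_grad():
    out, h_n = gru(x, h0)

print(out.shape)  # torch.Size([4, 10, 256]): batch-first, one output per time step
print(h_n.shape)  # torch.Size([2, 4, 256]): same layout as h0

# For a unidirectional GRU, the last time step of `out` equals the top layer's final hidden state
print(torch.allclose(out[:, -1], h_n[-1]))  # True
```

This is also why forward() can take out[:, -1] as the summary of each sequence.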

And then this is where I load the model and try to make a prediction. Both of these are in the same script.

inputs = np.load('.//Pred//input_list.npy')  
print(inputs.ndim, inputs.shape)
Gmodel = GRUNet(24,256,1,2)
Gmodel = torch.load('.//GRU//GRU_1028_48.pkl')
Gmodel.eval()
pred = Gmodel(inputs)

Calling Gmodel with only the inputs (no other arguments), I get the following:

Traceback (most recent call last):
  File ".\grunet.py", line 136, in <module>
    pred = Gmodel(inputs)
  File "C:\Users\ryang\Anaconda-3\envs\tf-gpu\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'h'
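The traceback comes from nn.Module.__call__, which passes its arguments straight to forward(). Because forward(self, x, h) declares h without a default, calling the model with one argument fails like any other Python function would. One possible alternative (not what the answer below does) is to give h a default of None, since nn.GRU treats a missing initial hidden state as zeros. A sketch, where GRUNetOptionalH is a hypothetical name and random data stands in for the real inputs:

```python
import torch
import torch.nn as nn


class GRUNetOptionalH(nn.Module):
    """Hypothetical variant of GRUNet whose forward() makes h optional."""

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, drop_prob=0.2):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True, dropout=drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x, h=None):
        # When h is None, nn.GRU starts from an all-zero hidden state
        out, h = self.gru(x, h)
        out = self.fc(self.relu(out[:, -1]))
        return out, h


model = GRUNetOptionalH(24, 256, 1, 2)
model.eval()
x = torch.randn(4, 10, 24)  # (batch, seq_len, input_dim)
with torch.no_grad():
    pred, h_n = model(x)  # no TypeError: h falls back to None
print(pred.shape)  # torch.Size([4, 1])
print(h_n.shape)   # torch.Size([2, 4, 256])
```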
R Godbey

1 Answer

You need to provide the hidden state as well, which is usually initialized to all zeros or simply passed as None.
That is, you either need to provide one explicitly, like this:

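# shape: (num_layers * num_directions, batch_size, hidden_dim); this GRU is unidirectional
hidden_state = torch.zeros(Gmodel.n_layers, inputs.shape[0], Gmodel.hidden_dim).to(device)
pred, hidden_state = Gmodel(inputs, hidden_state)  # forward() returns (prediction, hidden state)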
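The call above still passes inputs exactly as np.load returned it, but nn.GRU expects a tensor, and NumPy's default float64 would not match the float32 GRU weights. A sketch of a complete prediction step, assuming the saved model follows GRUNet's forward(x, h) signature and exposes n_layers and hidden_dim; predict and TinyGRUNet are hypothetical names, and random data stands in for input_list.npy:

```python
import numpy as np
import torch
import torch.nn as nn


def predict(model, inputs, device):
    """Hypothetical helper: run a GRUNet-style model on a NumPy batch.

    Assumes model.forward(x, h) returns (prediction, hidden) and that the
    model exposes n_layers and hidden_dim, as GRUNet does.
    """
    # NumPy arrays default to float64; the GRU weights are float32 tensors
    x = torch.as_tensor(inputs, dtype=torch.float32, device=device)
    num_directions = 1  # the GRU in GRUNet is unidirectional
    h0 = torch.zeros(model.n_layers * num_directions, x.size(0), model.hidden_dim, device=device)
    model.eval()
    with torch.no_grad():  # no gradients needed for inference
        pred, _ = model(x, h0)
    return pred


class TinyGRUNet(nn.Module):
    """Small stand-in for the loaded model, with the same forward(x, h) signature."""

    def __init__(self, input_dim=24, hidden_dim=16, output_dim=1, n_layers=2):
        super().__init__()
        self.hidden_dim, self.n_layers = hidden_dim, n_layers
        self.gru = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h):
        out, h = self.gru(x, h)
        return self.fc(out[:, -1]), h


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyGRUNet().to(device)
inputs = np.random.rand(8, 10, 24)  # stand-in for np.load(...): (batch, seq_len, 24), float64
pred = predict(model, inputs, device)
print(pred.shape)  # torch.Size([8, 1])
```

Reading the batch size from the input keeps the hidden state consistent with whatever batch the .npy file holds.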

or simply pass None (nn.GRU then starts from an all-zero hidden state):

hidden_state = None
pred, hidden_state = Gmodel(inputs, hidden_state)  # forward() returns (prediction, hidden state)
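To check that None behaves like an explicit zero state, a small sketch comparing both on a bare nn.GRU (sizes borrowed from GRUNet(24, 256, 1, 2); random data and a batch size of 4 are illustration values):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(24, 256, 2, batch_first=True).eval()
x = torch.randn(4, 10, 24)  # (batch, seq_len, input_dim)

with torch.no_grad():
    out_none, h_none = gru(x, None)                    # no initial state given
    out_zero, h_zero = gru(x, torch.zeros(2, 4, 256))  # explicit all-zero state

print(torch.allclose(out_none, out_zero), torch.allclose(h_none, h_zero))  # True True
```

So the two options in this answer are interchangeable for a fresh prediction; an explicit tensor only matters when you want to carry a non-zero state over from a previous batch.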
Hossein
  • That's another question that you need to post separately. If this answer solves your initial issue, please accept it as the answer so this question is done. Then ask your new question separately and we'll try to answer it as best we can. – Hossein Oct 30 '20 at 06:23