I am trying to build a model that learns to assign scores (real numbers) to sentences in a data set. I use RNNs (in PyTorch) for this purpose. I have defined a model:
class RNNModel1(nn.Module):
    def forward(self, input, hidden_0):
        embedded = self.embedding(input)
        output, hidden = self.rnn(embedded, hidden_0)
        output = self.linear(hidden)
        return output, hidden
The train function is as follows:
def train(model, optimizer, criterion, BATCH_SIZE, train_loader, clip):
    model.train(True)
    total_loss = 0
    hidden = model._init_hidden(BATCH_SIZE)
    for i, (batch_of_data, batch_of_labels) in enumerate(train_loader, 1):
        hidden = hidden.detach()
        model.zero_grad()
        output, hidden = model(batch_of_data, hidden)
        loss = criterion(output, sorted_batch_target_scores)
        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), clip)
        optimizer.step()
    return total_loss / len(train_loader.dataset)
When I run the code, I receive this error:
RuntimeError: Expected hidden size (2, 24, 50), got (2, 30, 50)
Batch size=30, Hidden size=50, Number of Layers=1, Bidirectional=True.
I receive that error on the last batch of data. To solve this problem, I checked the description of RNNs in the PyTorch documentation. RNNs in PyTorch take two input arguments and return two outputs. The input arguments are input and h_0. h_0 is a tensor containing the initial hidden state for each element in the batch, of size (num_layers*num_directions, batch, hidden size). The outputs are output and h_n. h_n is a tensor containing the hidden state for t=seq_len, of size (num_layers*num_directions, batch, hidden size).
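To make these shapes concrete, here is a minimal sketch using the settings listed above (hidden size 50, one layer, bidirectional, batch size 30). It assumes an nn.GRU with batch_first=True, since the RNN variant and layout are not stated here; input_size=8 and seq_len=5 are arbitrary values chosen only for illustration.

```python
import torch
import torch.nn as nn

# Settings from the post: hidden size 50, 1 layer, bidirectional, batch size 30.
# input_size=8 and seq_len=5 are arbitrary values chosen for this sketch.
batch_size, seq_len, input_size, hidden_size = 30, 5, 8, 50
num_layers, num_directions = 1, 2

# Assumption: a GRU with batch_first=True stands in for the unspecified self.rnn.
rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers,
             bidirectional=True, batch_first=True)

x = torch.randn(batch_size, seq_len, input_size)
h_0 = torch.zeros(num_layers * num_directions, batch_size, hidden_size)

output, h_n = rnn(x, h_0)
print(output.shape)  # torch.Size([30, 5, 100])
print(h_n.shape)     # torch.Size([2, 30, 50])
```

Note that h_0 and h_n keep the (num_layers*num_directions, batch, hidden size) layout even when batch_first=True; only input and output are affected by that flag.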
In all batches except the last one, h_0 and h_n have the same size. In the last batch, however, the number of elements may be smaller than the batch size. Therefore the size of h_n is (num_layers*num_directions, remained_elements_in_last_batch, hidden size), while the size of h_0 is still (num_layers*num_directions, batch_size, hidden size).
So I receive that error in the last batch of data.
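The situation can be reproduced in isolation. The following is a minimal sketch, not the actual training code: the dataset size of 114 is hypothetical (three full batches of 30 plus 24 leftover samples, matching the 24 in the error message), the data is random, and an nn.GRU with batch_first=True is assumed in place of the model's RNN.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset: 114 samples = 3 full batches of 30 + 24 left over.
data = torch.randn(114, 5, 8)   # (samples, seq_len, input_size), random values
labels = torch.randn(114)
loader = DataLoader(TensorDataset(data, labels), batch_size=30)

# Number of elements in each batch produced by the loader.
sizes = [x.size(0) for x, _ in loader]
print(sizes)  # [30, 30, 30, 24]

# Assumption: a bidirectional single-layer GRU stands in for the model's RNN.
rnn = nn.GRU(8, 50, num_layers=1, bidirectional=True, batch_first=True)

# h_0 built from the configured batch size (30), as in the train function.
h_0 = torch.zeros(2, 30, 50)
last_batch = data[-24:]  # the incomplete final batch

try:
    rnn(last_batch, h_0)
    mismatch_raised = False
except RuntimeError as err:
    # The RNN rejects an h_0 whose batch dimension differs from the input's.
    mismatch_raised = True
    print(err)
```

The first three batches pass because the input and h_0 both have a batch dimension of 30; only the final batch of 24 triggers the size check.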
How can I solve this problem and handle the situation in which the sizes of h_0 and h_n differ?
Thanks in advance.