I'm a beginner learning to build a standard transformer model in PyTorch to solve a univariate sequence-to-sequence regression problem. The code follows the PyTorch tutorial, but the training/validation error turns out to be quite different from the test error.

During training, it goes like this:

    for src, tgt in train_loader:
        optimizer.zero_grad()
        output = net(src=src, tgt=tgt, device=device)
        loss = criterion(output[:,:-1,:], tgt[:,1:,:])   # is this correct?
        loss.backward()
        optimizer.step()

where the target sequence tgt is prefixed with a fixed value (0.1) to mimic the SOS token, and the output sequence output is shifted by one step (its last position is dropped) to mimic the EOS token. The transformer is trained with a triangular target mask to mimic autoregressive decoding at inference time, when the target sequence is not available.
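To make the training-time scheme concrete, here is a minimal, self-contained sketch under stated assumptions: TinySeq2Seq is a hypothetical stand-in for the question's net (scalar values projected with nn.Linear, learned positional embeddings, a one-layer nn.Transformer with batch_first=True, and no device argument), the data are random tensors, and the triangular mask is built directly with torch.triu. It runs one teacher-forced step with the same slicing as the loop above, then checks in eval mode that overwriting tgt from position 5 onward leaves the outputs at positions 0 to 4 unchanged, which is what the triangular mask is meant to guarantee.

```python
import torch
import torch.nn as nn

SOS = 0.1  # fixed start value, as described in the question


class TinySeq2Seq(nn.Module):
    """Hypothetical univariate seq2seq transformer (a stand-in for the question's net)."""

    def __init__(self, d_model=16, nhead=2, num_layers=1, max_len=64):
        super().__init__()
        self.in_proj = nn.Linear(1, d_model)  # scalar value -> d_model features
        self.pos = nn.Parameter(0.02 * torch.randn(1, max_len, d_model))  # learned positions
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=32, dropout=0.1, batch_first=True)
        self.out_proj = nn.Linear(d_model, 1)  # d_model features -> scalar prediction

    def forward(self, src, tgt):
        # src: (batch, S, 1), tgt: (batch, T, 1)
        T = tgt.size(1)
        # Triangular mask: -inf above the diagonal, so position i only sees tgt[:, :i+1].
        tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
        s = self.in_proj(src) + self.pos[:, : src.size(1)]
        t = self.in_proj(tgt) + self.pos[:, :T]
        return self.out_proj(self.transformer(s, t, tgt_mask=tgt_mask))


torch.manual_seed(0)
net = TinySeq2Seq()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
criterion = nn.MSELoss()

src = torch.randn(4, 10, 1)  # synthetic source sequences
y = torch.randn(4, 8, 1)     # synthetic "true" target values
tgt = torch.cat([torch.full((4, 1, 1), SOS), y], dim=1)  # SOS prefix -> (4, 9, 1)

# One teacher-forced training step with the same slicing as the question's loop.
net.train()
optimizer.zero_grad()
output = net(src, tgt)                              # (4, 9, 1)
loss = criterion(output[:, :-1, :], tgt[:, 1:, :])  # output[:, i] is compared with tgt[:, i+1]
loss.backward()
optimizer.step()

# Look-ahead check: overwrite tgt from position 5 on and confirm that the
# outputs at positions 0..4 do not change (eval mode, so dropout is off).
net.eval()
with torch.no_grad():
    a = net(src, tgt)
    tgt_changed = tgt.clone()
    tgt_changed[:, 5:, :] = torch.randn(4, 4, 1)
    b = net(src, tgt_changed)
print("loss:", loss.item())
print("outputs before position 5 unchanged:", torch.allclose(a[:, :5], b[:, :5], atol=1e-6))
```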

During testing, it goes like this:

    net.eval()
    with torch.no_grad():
        for src, tgt in test_loader:
            outputs = torch.zeros(tgt.size())

            temp = (torch.rand(tgt.size())*2-1)
            temp[:,0,:] = 0.1*torch.ones(tgt[:,0,:].size())   # prefix to mimic SOS

            for t in range(1, temp.size()[1]):
                outputs = net(src=src, tgt=temp, device=device)
                temp[:,t,:] = outputs[:,t-1,:]      # is this correct?

            outputs = net(src, temp, device=device)      # is this correct?

            print(criterion(outputs[:,:-1,:], tgt[:,1:,:]))
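The test loop above can be checked in isolation. The sketch below uses the same kind of hypothetical TinySeq2Seq stand-in (its forward takes (src, tgt) and has no device argument, unlike the question's net) and two hypothetical helpers: decode_fixed_buffer mirrors the question's loop with random placeholders filled left to right, and decode_growing builds the decoder input one prediction at a time with torch.cat. Assuming a causal target mask and eval mode (dropout off), the two should produce the same sequence up to floating-point noise, and the extra full forward pass after the loop should reproduce the decoded values in final[:, :-1].

```python
import torch
import torch.nn as nn

SOS = 0.1  # fixed start value, as in the question


class TinySeq2Seq(nn.Module):
    """Hypothetical univariate seq2seq transformer (a stand-in for the question's net)."""

    def __init__(self, d_model=16, nhead=2, num_layers=1, max_len=64):
        super().__init__()
        self.in_proj = nn.Linear(1, d_model)
        self.pos = nn.Parameter(0.02 * torch.randn(1, max_len, d_model))
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=32, dropout=0.1, batch_first=True)
        self.out_proj = nn.Linear(d_model, 1)

    def forward(self, src, tgt):
        T = tgt.size(1)
        # Causal mask: position i may only attend to decoder positions <= i.
        tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
        s = self.in_proj(src) + self.pos[:, : src.size(1)]
        t = self.in_proj(tgt) + self.pos[:, :T]
        return self.out_proj(self.transformer(s, t, tgt_mask=tgt_mask))


@torch.no_grad()
def decode_fixed_buffer(net, src, tgt_len):
    """Mirror of the question's loop: random placeholders, filled left to right."""
    temp = torch.rand(src.size(0), tgt_len, 1) * 2 - 1
    temp[:, 0, :] = SOS
    for t in range(1, tgt_len):
        outputs = net(src, temp)
        temp[:, t, :] = outputs[:, t - 1, :]  # prediction for position t
    return temp


@torch.no_grad()
def decode_growing(net, src, tgt_len):
    """Same decoding, but the decoder input only holds positions already known."""
    dec_in = torch.full((src.size(0), 1, 1), SOS)
    for _ in range(1, tgt_len):
        outputs = net(src, dec_in)
        dec_in = torch.cat([dec_in, outputs[:, -1:, :]], dim=1)  # append newest prediction
    return dec_in


torch.manual_seed(0)
net = TinySeq2Seq().eval()  # eval() turns dropout off so repeated passes are deterministic
src = torch.randn(3, 10, 1)
tgt_len = 9                 # SOS + 8 predicted values

temp = decode_fixed_buffer(net, src, tgt_len)
grown = decode_growing(net, src, tgt_len)
with torch.no_grad():
    final = net(src, temp)  # the extra full pass after the loop, as in the question

print("fixed buffer matches growing input:", torch.allclose(temp, grown, atol=1e-4))
print("final pass reproduces decoded values:",
      torch.allclose(final[:, :-1, :], temp[:, 1:, :], atol=1e-5))
```

Because the causal mask hides the not-yet-filled positions, the random placeholders never influence earlier outputs; the growing variant simply avoids creating them, which is mostly a matter of clarity and a little less computation.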

During training, both the training loss and the validation loss (MSE) decrease and converge smoothly. However, the test loss turns out to be much larger than either of them. Could anyone check whether this is the correct way to do inference with a transformer model?
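The question does not show how the validation loss is computed, so the following rests on an assumption: if it is computed like the training loss, with the ground-truth tgt fed to the decoder (teacher forcing), it measures one-step-ahead error, whereas the test loop feeds the model its own earlier predictions, so errors can compound along the sequence. A minimal sketch that computes both numbers on the same batch, again with a hypothetical TinySeq2Seq stand-in, random data, and two hypothetical helpers (teacher_forced_loss and autoregressive_loss):

```python
import torch
import torch.nn as nn

SOS = 0.1  # fixed start value, as in the question


class TinySeq2Seq(nn.Module):
    """Hypothetical univariate seq2seq transformer (a stand-in for the question's net)."""

    def __init__(self, d_model=16, nhead=2, num_layers=1, max_len=64):
        super().__init__()
        self.in_proj = nn.Linear(1, d_model)
        self.pos = nn.Parameter(0.02 * torch.randn(1, max_len, d_model))
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_layers, num_decoder_layers=num_layers,
            dim_feedforward=32, dropout=0.1, batch_first=True)
        self.out_proj = nn.Linear(d_model, 1)

    def forward(self, src, tgt):
        T = tgt.size(1)
        tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)  # causal mask
        s = self.in_proj(src) + self.pos[:, : src.size(1)]
        t = self.in_proj(tgt) + self.pos[:, :T]
        return self.out_proj(self.transformer(s, t, tgt_mask=tgt_mask))


@torch.no_grad()
def teacher_forced_loss(net, src, tgt, criterion):
    """Validation-style loss: the decoder sees the ground-truth, SOS-prefixed tgt."""
    out = net(src, tgt)
    return criterion(out[:, :-1, :], tgt[:, 1:, :]).item()


@torch.no_grad()
def autoregressive_loss(net, src, tgt, criterion):
    """Test-style loss: the decoder sees only SOS plus its own earlier predictions."""
    dec_in = tgt[:, :1, :]  # the SOS column
    for _ in range(1, tgt.size(1)):
        out = net(src, dec_in)
        dec_in = torch.cat([dec_in, out[:, -1:, :]], dim=1)
    return criterion(dec_in[:, 1:, :], tgt[:, 1:, :]).item()


torch.manual_seed(0)
net = TinySeq2Seq().eval()  # dropout off for evaluation
criterion = nn.MSELoss()
src = torch.randn(4, 10, 1)
tgt = torch.cat([torch.full((4, 1, 1), SOS), torch.randn(4, 8, 1)], dim=1)

tf = teacher_forced_loss(net, src, tgt, criterion)
ar = autoregressive_loss(net, src, tgt, criterion)
print(f"teacher-forced MSE: {tf:.4f}   autoregressive MSE: {ar:.4f}")
```

Computing both numbers on the real validation set with the trained net would indicate how much of the gap comes from feeding predictions back rather than from the data itself.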

(Btw, I couldn't find many examples of transformer models for univariate sequence regression on Google; any recommended links would be really appreciated!)

Asked by Haoran Li; edited by Progman.