
I designed a network for a text classification problem. To do this, I'm using the Hugging Face Transformers BERT model with a linear layer on top for fine-tuning. My problem is that the loss on the training set decreases, which is fine, but when I evaluate on the development set after each epoch, the loss increases with every epoch. I'm posting my code so you can check whether there is something wrong with it.

for epoch in range(1, args.epochs + 1):
    total_train_loss = 0
    trainer.set_train()

    for step, batch in enumerate(train_dataloader):
        loss = trainer.step(batch)
        total_train_loss += loss

    avg_train_loss = total_train_loss / len(train_dataloader)

    logger.info(('Training loss for epoch %d/%d: %4.2f') % (epoch, args.epochs, avg_train_loss))

    print("\n-------------------------------")
    logger.info('Start validation ...')
    trainer.set_eval()
    y_hat = list()
    y = list()
    total_dev_loss = 0
    for step, batch_val in enumerate(dev_dataloader):
        true_labels_ids, predicted_labels_ids, loss = trainer.validate(batch_val)
        total_dev_loss += loss
        y.extend(true_labels_ids)
        y_hat.extend(predicted_labels_ids)
    avg_dev_loss = total_dev_loss / len(dev_dataloader)
    print(("\n-Total dev loss: %4.2f on epoch %d/%d\n") % (avg_dev_loss, epoch, args.epochs))

print("Training terminated!")

Following is the trainer file, which I use for doing a forward pass on a given batch and then backpropagate accordingly.

class Trainer(object):
    def __init__(self, args, model, device, data_points, is_test=False, train_stats=None):
        self.args = args
        self.model = model
        self.device = device
        self.loss = nn.CrossEntropyLoss(reduction='none')

        if is_test:
            # Should load the model from checkpoint
            self.model.eval()       
            self.model.load_state_dict(torch.load(args.saved_model))
            logger.info('Loaded saved model from %s' % args.saved_model)

        else:
            self.model.train()
            self.optim = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
            total_steps = data_points * self.args.epochs
            self.scheduler = get_linear_schedule_with_warmup(self.optim, num_warmup_steps=0,
                                                             num_training_steps=total_steps)

    def step(self, batch):
        batch = tuple(t.to(self.device) for t in batch)
        batch_input_ids, batch_input_masks, batch_labels = batch
        self.model.zero_grad()
        outputs = self.model(batch_input_ids,
                             attention_mask=batch_input_masks,
                             labels=batch_labels)
        loss = self.loss(outputs, batch_labels)
        loss = loss.sum()
        (loss / loss.numel()).backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optim.step()
        self.scheduler.step()
        return loss

    def validate(self, batch):
        batch = tuple(t.to(self.device) for t in batch)
        batch_input_ids, batch_input_masks, batch_labels = batch
        with torch.no_grad():
            model_output = self.model(batch_input_ids,
                                 attention_mask=batch_input_masks,
                                 labels=batch_labels)

        predicted_label_ids = self._predict(model_output)
        label_ids = batch_labels.to('cpu').numpy()

        loss = self.loss(model_output, batch_labels)
        loss = loss.sum()

        return label_ids, predicted_label_ids, loss

    def _predict(self, logits):
        return np.argmax(logits.to('cpu').numpy(), axis=1)

Finally, the following is my model (i.e., Classifier) class:

import torch.nn as nn
from transformers import BertModel


class Classifier(nn.Module):

    def __init__(self, args, is_eval=False):
        super(Classifier, self).__init__()

        self.bert_model = BertModel.from_pretrained(
            args.init_checkpoint,
            output_attentions=False,
            output_hidden_states=True,
        )
        self.is_eval_mode = is_eval
        self.linear = nn.Linear(768, 2) # binary classification

    def switch_state(self):
        self.is_eval_mode = not self.is_eval_mode

    def forward(self, input_ids, attention_mask=None, labels=None):

        bert_outputs = self.bert_model(input_ids,
                                       token_type_ids=None,
                                       attention_mask=attention_mask)

        # The pooled output (bert_outputs[1]) goes through the linear layer to produce the logits
        model_output = self.linear(bert_outputs[1])

        return model_output

Here is a visualization of the loss throughout the epochs:

[plot of training and validation loss per epoch]

2 Answers


When I've used BERT for text classification, my model has generally behaved the way you describe. In part this is expected, because pre-trained models tend to require only a few epochs to fine-tune; in fact, if you check the BERT paper, the recommended number of epochs for fine-tuning is between 2 and 4.

On the other hand, I've usually found the optimum at just 1 or 2 epochs, which also matches your case. My guess is that there is a trade-off when fine-tuning pre-trained models between fitting your downstream task and forgetting the weights learned during pre-training. Depending on the data you have, the equilibrium point may come sooner or later, and overfitting starts after that. But this paragraph is speculation based on my experience.
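
If it helps, here is a minimal sketch of how you could act on that in your training loop: keep the checkpoint from the epoch with the lowest dev loss and stop once it stops improving. best_dev_loss, patience and epochs_without_improvement are names I made up; trainer, avg_dev_loss, args.saved_model and logger are taken from your code.

import torch

best_dev_loss = float('inf')   # lowest dev loss seen so far (illustrative name)
epochs_without_improvement = 0
patience = 1                   # stop after one epoch without improvement (illustrative name)

for epoch in range(1, args.epochs + 1):
    # ... training and validation exactly as in your loop, producing avg_dev_loss ...

    if float(avg_dev_loss) < best_dev_loss:
        best_dev_loss = float(avg_dev_loss)
        epochs_without_improvement = 0
        # keep the weights of the best epoch so far
        torch.save(trainer.model.state_dict(), args.saved_model)
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            logger.info('Dev loss stopped improving, stopping at epoch %d' % epoch)
            break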

Javier Beltrán

When validation loss increases, it means your model is overfitting.
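
A quick way to confirm this from your own numbers is to record both losses per epoch and look at where they diverge; roughly, something like the sketch below (train_losses and dev_losses are illustrative names, avg_train_loss and avg_dev_loss come from your loop).

train_losses, dev_losses = [], []   # one entry per epoch (illustrative names)

for epoch in range(1, args.epochs + 1):
    # ... train and validate as before, producing avg_train_loss and avg_dev_loss ...
    train_losses.append(float(avg_train_loss))
    dev_losses.append(float(avg_dev_loss))

    # overfitting signature: training loss keeps falling while dev loss rises
    if len(dev_losses) > 1 and dev_losses[-1] > dev_losses[-2] and train_losses[-1] <= train_losses[-2]:
        logger.info('Dev loss went up at epoch %d while train loss did not -> likely overfitting' % epoch)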

  • From what I see in the visualizations, I'm sure I'm overfitting on the training data after possibly epoch 2. What I'm not sure about is the validity of my implementation, i.e., whether there's something wrong with it or not... – inverted_index May 03 '20 at 19:04