
In the final layers of HuggingFace's sequence-classification head, the hidden state at the first position of the sequence dimension of the transformer output is used for classification.

hidden_state = distilbert_output[0]  # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0]  # (bs, dim)           <-- first hidden state
pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
pooled_output = self.dropout(pooled_output)  # (bs, dim)
logits = self.classifier(pooled_output)  # (bs, num_labels)

Is there any benefit to taking the first hidden state over the last one, the average over the sequence, or even using a Flatten layer instead?

doe

1 Answer


Yes, this is directly related to the way that BERT is trained. Specifically, I encourage you to have a look at the original BERT paper, in which the authors introduce the meaning of the [CLS] token:

[CLS] is a special symbol added in front of every input example [...].

Specifically, it is used for classification purposes, and is therefore the first and simplest choice when fine-tuning for classification tasks. What your code fragment is doing is basically just extracting the hidden state of this [CLS] token.
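
You can see this concretely by inspecting what the tokenizer feeds into the model. A minimal sketch, assuming the distilbert-base-uncased checkpoint (any BERT-style checkpoint behaves the same way):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoding = tokenizer("This movie was great")
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
# ['[CLS]', 'this', 'movie', 'was', 'great', '[SEP]']
# -> position 0 of the hidden states, i.e. hidden_state[:, 0], belongs to [CLS]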

Unfortunately, the DistilBERT documentation in Huggingface's library does not explicitly mention this; you rather have to check out their BERT documentation, where they also highlight some issues with the [CLS] token, analogous to your concerns:

Alongside MLM, BERT was trained using a next sentence prediction (NSP) objective using the [CLS] token as a sequence approximate. The user may use this token (the first token in a sequence built with special tokens) to get a sequence prediction rather than a token prediction. However, averaging over the sequence may yield better results than using the [CLS] token.
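
If you want to try that averaging suggestion, a minimal sketch of masked mean pooling as a drop-in replacement for hidden_state[:, 0] could look like this (assuming you also have the attention_mask from the tokenizer, shape (bs, seq_len); the helper name mean_pool is just illustrative):

import torch

def mean_pool(hidden_state, attention_mask):
    # hidden_state: (bs, seq_len, dim), attention_mask: (bs, seq_len)
    mask = attention_mask.unsqueeze(-1).type_as(hidden_state)  # (bs, seq_len, 1)
    summed = (hidden_state * mask).sum(dim=1)                  # (bs, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)                   # (bs, 1), avoid division by zero
    return summed / counts                                     # (bs, dim)

# pooled_output = mean_pool(hidden_state, attention_mask)  # instead of hidden_state[:, 0]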

dennlinger
  • +1. If averaging over the embeddings of the sequence could yield better results, why didn't the authors adopt this approach? – avocado Aug 20 '21 at 21:56
  • I presume that the alternative is more compute-intensive and therefore not worth the (maybe only marginal) gains. – dennlinger Aug 22 '21 at 09:42