
What is a multi-headed model in deep learning?

The only explanation I have found so far is this: every model can be thought of as a backbone plus a head, and if you pre-train the backbone and attach a randomly initialized head, you can fine-tune it, which is a good idea.
Can someone please provide a more detailed explanation.

spacer.34

2 Answers


The explanation you found is accurate. Depending on what you want to predict from your data, you need an appropriate backbone network and a certain number of prediction heads.

For a basic classification network, for example, you can view ResNet, AlexNet, VGGNet, Inception, ... as the backbone and the fully connected layer as the sole prediction head.

A good example of a problem where you need multiple heads is localization, where you not only want to classify what is in the image but also want to localize the object (find the coordinates of the bounding box around it).

The image below shows the general architecture:

[Image: general multi-head architecture]

The backbone network ("convolution and pooling") is responsible for extracting a feature map from the image that contains higher-level, summarized information. Each head uses this feature map as input to predict its desired outcome.

The loss that you optimize for during training is usually a weighted sum of the individual losses for each prediction head.
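As a rough illustration of the weighted-sum loss, here is a plain-Python sketch of a toy localization model: one shared backbone, a classification head, and a bounding-box regression head. The tiny layer sizes, the choice of cross-entropy for the class head and mean squared error for the box head, and the weights `w_cls`/`w_box` are assumptions for illustration; a real model would use a deep-learning framework with learned parameters.

```python
import math

# Hypothetical toy "backbone": one linear map plus ReLU that turns the raw
# input into a shared feature vector used by every head.
def backbone(x, weights):
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in weights]

# A head here is just a linear layer (weights + bias) on the shared features.
def linear(features, weights, bias):
    return [sum(w * f for w, f in zip(row, features)) + b for row, b in zip(weights, bias)]

def softmax_cross_entropy(logits, target_class):
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target_class]

def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def multi_head_loss(x, target_class, target_box, params, w_cls=1.0, w_box=1.0):
    features = backbone(x, params["backbone"])        # shared feature vector
    logits = linear(features, *params["cls_head"])    # head 1: class scores
    box = linear(features, *params["box_head"])       # head 2: 4 box coordinates
    loss_cls = softmax_cross_entropy(logits, target_class)
    loss_box = mse(box, target_box)
    # Total training loss: weighted sum of the per-head losses.
    return w_cls * loss_cls + w_box * loss_box, loss_cls, loss_box
```

With `w_box=0` this reduces to an ordinary single-head classifier, so the backbone-plus-one-head case is just a special case of the same structure.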

SaiBot
  • So, as far as I understand, each 'head' is responsible for a specific task and the final model on which we are fitting the data is the blend of those 'heads'? – spacer.34 May 06 '19 at 11:53
  • @zoandr correct. I added a bit more information on this. – SaiBot May 06 '19 at 11:59
  • If I have to solve a multi-label classification problem, does that mean I have to use a multi-headed model? – spacer.34 May 06 '19 at 12:07
  • @zoandr yes you can do that, however you could also transform the problem to a multi-class classification problem and go back to one head. – SaiBot May 06 '19 at 12:10
  • But in the case of multi-label classification, what are the 'heads'? I used GloVe model for vectorization of text and then LabelPowerset and RandomForestClassifier for the fitting. Are those three the 'heads' in my case? – spacer.34 May 06 '19 at 12:15
  • @zoandr there is a lot to unpack here. Not sure if the comment section is suited for that, but: GloVe gives you word vectors; it is not clear how you get text vectors from those. LabelPowerset does convert multi-label to multi-class. RandomForestClassifier is a classifier. To me this sounds like a three-step approach with a pretrained word-to-vector model, label transformation and classification. "Prediction heads" is a term that is usually used in the end-to-end neural network world. – SaiBot May 06 '19 at 12:34
  • Thanks a lot for your detailed explanations! – spacer.34 May 06 '19 at 12:46
  • I cannot find where this idea comes from. Could you tell me which paper proposed this / coined the term "multi-headed model"? – Shuai May 13 '21 at 11:40
  • Image source for those looking: https://towardsdatascience.com/detection-and-segmentation-through-convnets-47aa42de27ea (source your images please) – Logan Cundiff Dec 12 '21 at 19:52
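SaiBot's suggestion of converting a multi-label problem into a multi-class one (which is what LabelPowerset does) can be sketched in plain Python. This is a minimal illustration of the label-powerset idea, assuming each distinct combination of labels becomes its own class; `label_powerset` and `decode` are hypothetical names, not the scikit-multilearn API.

```python
def label_powerset(label_sets):
    """Map each distinct combination of labels to one class id.

    This turns a multi-label problem into a multi-class one, so a single
    classification head (one softmax over the combinations) is enough.
    """
    combo_to_class = {}
    class_ids = []
    for labels in label_sets:
        key = frozenset(labels)  # label order does not matter
        if key not in combo_to_class:
            combo_to_class[key] = len(combo_to_class)
        class_ids.append(combo_to_class[key])
    return class_ids, combo_to_class

def decode(class_id, combo_to_class):
    """Turn a predicted class id back into its set of labels."""
    for combo, cid in combo_to_class.items():
        if cid == class_id:
            return set(combo)
    raise KeyError(class_id)
```

The trade-off is that the number of classes grows with every label combination seen in training, and a combination never seen in training cannot be predicted; keeping one output per label avoids that.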

The head is the top of a network. For instance, at the bottom (where data comes in) you take the convolution layers of some model, say ResNet. If you call ConvLearner.pretrained, ConvnetBuilder will build a network with a head appropriate to your data in Fast.ai (if you are working on a classification problem, it will create a head with a cross-entropy loss; if you are working on a regression problem, it will create a head suited to that).

But you could build a model that has multiple heads. The model could take inputs from the base network (ResNet conv layers) and feed the activations to some model, say head1, and then the same data to head2. Or you could have some number of shared layers built on top of ResNet, with only those layers feeding head1 and head2.
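The two layouts described above can be sketched in plain Python, with tiny hypothetical weight matrices standing in for the ResNet conv layers and the heads (`relu_linear`, `heads_on_backbone`, and `heads_on_shared_trunk` are illustrative names, not fastai functions). In the first layout both heads read the backbone activations directly; in the second, extra shared layers sit between the backbone and both heads.

```python
def relu_linear(x, weights):
    """Linear layer followed by ReLU (stand-in for a block of conv layers)."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in weights]

def linear(x, weights):
    """Plain linear layer used as a prediction head."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def heads_on_backbone(x, p):
    # Layout 1: the same backbone activations go straight to head1 and head2.
    feats = relu_linear(x, p["backbone"])
    return linear(feats, p["head1"]), linear(feats, p["head2"])

def heads_on_shared_trunk(x, p):
    # Layout 2: some shared layers on top of the backbone feed both heads.
    feats = relu_linear(x, p["backbone"])
    shared = relu_linear(feats, p["trunk"])
    return linear(shared, p["head1"]), linear(shared, p["head2"])
```

In the second layout the trunk's parameters receive gradients from both heads' losses during training, which is the usual reason for sharing layers.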

You could even have different layers feed different heads! There are some nuances to this (for instance, with regard to the fastai lib, ConvnetBuilder will add an AdaptivePooling layer on top of the base network if you don’t specify the custom_head argument, but it won't if you do), but this is the general picture.
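Two of those nuances can be illustrated with a toy plain-Python sketch: heads that read from different depths of the backbone, and an adaptive average pooling step that squeezes a variable-length feature map to a fixed size before a head. The binning rule follows the floor/ceil scheme commonly used for adaptive pooling (for example in PyTorch's `AdaptiveAvgPool1d`); the backbone stages and head functions are hypothetical stand-ins, not what ConvnetBuilder actually builds.

```python
def adaptive_avg_pool1d(values, output_size):
    """Average a 1-D feature map into exactly `output_size` (possibly overlapping) bins."""
    n = len(values)
    pooled = []
    for i in range(output_size):
        start = (i * n) // output_size            # floor(i * n / output_size)
        end = -(-((i + 1) * n) // output_size)    # ceil((i + 1) * n / output_size)
        window = values[start:end]
        pooled.append(sum(window) / len(window))
    return pooled

def backbone_stages(signal):
    """Hypothetical two-stage backbone that exposes every intermediate activation."""
    early = [max(0.0, v) for v in signal]                  # stage 1: ReLU
    late = [a + b for a, b in zip(early[:-1], early[1:])]  # stage 2: width-2 sum "conv"
    return early, late

def multi_tap_heads(signal, early_bins=2):
    early, late = backbone_stages(signal)
    head1 = adaptive_avg_pool1d(early, early_bins)  # head fed by an early layer
    head2 = adaptive_avg_pool1d(late, 1)[0]         # head fed by the last layer
    return head1, head2
```

Because the pooled output has a fixed length whatever the input length, the same head can be attached regardless of input size, which is one common reason to put adaptive pooling in front of a default head.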

unrealapex
Alberto Tono