
I want to add a classification layer in PyTorch on top of the Hugging Face ViLT transformer so that I can classify my text labels.

Normally, ViLT takes an (image, question) pair and, after a forward pass, outputs the answer to the question.

I want to turn this into a classification task instead of a text-generation task. I have a fixed set of labels, and I want ViLT to tell me which label has the highest probability of being the answer to the given question.
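In other words, this is multi-class classification over a fixed answer set: each label gets an integer id, the model produces one score (logit) per id, and a softmax turns those scores into probabilities whose argmax is the predicted label. A minimal pure-Python sketch of that bookkeeping, using a hypothetical label list and made-up logits in place of real ViLT output:

```python
import math

# Hypothetical answer vocabulary; replace with your own labels.
labels = ["yes", "no", "red", "blue", "two"]

# Map each label to an integer class id and back (losses such as
# CrossEntropyLoss expect integer class ids, not strings).
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for label, i in label2id.items()}


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Return the most probable label and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]


# Made-up logits, one per label, standing in for a classification head's output.
label, prob = predict([0.1, 2.3, -1.0, 0.5, 0.0])
print(label, round(prob, 3))
```

The argmax of the logits and the argmax of the softmax probabilities always pick the same label; the softmax is only needed when you want the actual probabilities.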

I'm completely new to transformers and have very little idea of how to achieve this. Can someone please help?

I checked this Medium blog post but couldn't make sense of it.

user10418143

1 Answer


You can add your own classification head on top of the ViLT model.

This is only an overview; adapt it to your requirements.

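```python
import torch
import torch.nn as nn
from transformers import ViltModel


class ClassificationHead(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.fc(x)


# Load the ViLT backbone (example checkpoint; use whichever ViLT checkpoint you need)
vilt_model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

# Define your number of classes
num_classes = ...  # Number of classes in your classification task

# Create the classification head on top of ViLT's hidden size
classification_head = ClassificationHead(vilt_model.config.hidden_size, num_classes)

# Optimize backbone and head together (pass only the head's parameters to freeze ViLT)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    list(vilt_model.parameters()) + list(classification_head.parameters()), lr=5e-5
)

# Training loop (num_epochs and dataloader are yours to define)
vilt_model.train()
classification_head.train()
for epoch in range(num_epochs):
    for batch in dataloader:
        # Each batch is assumed to hold ViltProcessor outputs (input_ids, attention_mask,
        # token_type_ids, pixel_values, pixel_mask) plus integer class ids under "labels"
        labels = batch.pop("labels")

        # Forward pass: ViLT needs the image AND the question, not only input_ids
        outputs = vilt_model(**batch).last_hidden_state[:, 0, :]  # [CLS] token embedding
        logits = classification_head(outputs)

        # Calculate loss
        loss = criterion(logits, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Inference
vilt_model.eval()
classification_head.eval()
with torch.no_grad():
    inputs = ...  # Prepare your input data, e.g. ViltProcessor output for (image, question)
    outputs = vilt_model(**inputs).last_hidden_state[:, 0, :]
    logits = classification_head(outputs)
    predicted_labels = logits.argmax(dim=1)
```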
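One way to keep the backbone and the head together is to wrap them in a single nn.Module, so that one optimizer, one .train()/.eval() call and one state_dict cover both. This is a sketch under assumptions: the backbone is only required to expose config.hidden_size and to return an object with last_hidden_state, as ViltModel does. ViltClassifier and DummyBackbone are hypothetical names, and DummyBackbone is a stand-in so the example runs offline without downloading ViLT weights.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class ViltClassifier(nn.Module):
    """Backbone + linear head in one module (hypothetical wrapper)."""

    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(backbone.config.hidden_size, num_classes)

    def forward(self, **inputs):
        hidden = self.backbone(**inputs).last_hidden_state  # (batch, seq_len, hidden)
        return self.head(hidden[:, 0, :])  # classify from the first ([CLS]) token


class DummyBackbone(nn.Module):
    """Hypothetical stand-in for ViltModel: mimics only .config.hidden_size
    and an output object with .last_hidden_state."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.config = SimpleNamespace(hidden_size=hidden_size)
        self.proj = nn.Linear(4, hidden_size)

    def forward(self, features):
        return SimpleNamespace(last_hidden_state=self.proj(features))


torch.manual_seed(0)
model = ViltClassifier(DummyBackbone(), num_classes=5)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # covers backbone and head
criterion = nn.CrossEntropyLoss()

features = torch.randn(8, 10, 4)    # (batch, seq_len, feature_dim) dummy inputs
labels = torch.randint(0, 5, (8,))  # one integer class id per example

# One training step
model.train()
logits = model(features=features)   # shape (8, 5)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Inference
model.eval()
with torch.no_grad():
    predicted = model(features=features).argmax(dim=1)
```

With the real backbone, you would pass a loaded ViltModel instead of DummyBackbone and call the wrapper with the processor outputs as keyword arguments, e.g. model(**batch).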

Anay