I followed this link, but it's implemented in Keras:

Cannot add CRF layer on top of BERT in keras for NER

Model description

Is it possible to add a simple custom pytorch-crf layer on top of a TokenClassification model? It would make the model more robust.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchcrf import CRF
from transformers import BertTokenizer, BertConfig, BertForTokenClassification, TrainingArguments, Trainer

model_checkpoint = "dslim/bert-base-NER"
tokenizer = BertTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
config = BertConfig.from_pretrained(model_checkpoint, output_hidden_states=True)
bert_model = BertForTokenClassification.from_pretrained(model_checkpoint, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)


class BERT_CRF(nn.Module):

    def __init__(self, bert_model, num_labels):
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)

        # classifier over the concatenation of the last four 768-dim hidden states
        self.classifier = nn.Linear(4*768, num_labels)

        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)

        sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1)  # <-- error raised here
        sequence_output = self.dropout(sequence_output)

        emission = self.classifier(sequence_output)  # [batch, seq_len, num_labels], e.g. [32, 256, 17]

        if labels is not None:
            labels = labels.reshape(attention_mask.size()[0], attention_mask.size()[1])
            loss = -self.crf(F.log_softmax(emission, dim=2), labels, mask=attention_mask.type(torch.uint8), reduction='mean')
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return [loss, prediction]
        else:
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return prediction
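
For completeness, the model instance passed to the Trainer below is created like this (assuming num_labels should match the label maps used above):

model = BERT_CRF(bert_model, num_labels=len(label2id))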

args = TrainingArguments(
    "spanbert_crf_ner-pos2",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    # per_device_eval_batch_size=32
    fp16=True
    # bf16=True #Ampere GPU
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    # eval_dataset=train_data,
    # data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer)

I get the error on the sequence_output = torch.cat(...) line marked above.

outputs = self.bert(input_ids, attention_mask=attention_mask) only gives the logits for token classification. How can I get the hidden states, so that I can concatenate the last four of them, i.e. do outputs[1][-1] through outputs[1][-4]?
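
From the transformers docs, I suspect the problem is that config (created above with output_hidden_states=True) is never actually passed to from_pretrained, so the model returns only logits. Here is an untested sketch of what I think should expose the hidden states:

# pass the config (with output_hidden_states=True) when loading the model
config = BertConfig.from_pretrained(model_checkpoint, output_hidden_states=True,
                                    id2label=id2label, label2id=label2id)
bert_model = BertForTokenClassification.from_pretrained(
    model_checkpoint, config=config, ignore_mismatched_sizes=True)

# or request them per call inside forward():
outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
hidden_states = outputs.hidden_states  # tuple of 13 tensors, each [batch, seq_len, 768]
sequence_output = torch.cat((hidden_states[-1], hidden_states[-2],
                             hidden_states[-3], hidden_states[-4]), -1)

Is this the right way to do it?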

Or is there an easier way to implement a BERT-CRF model?
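
In case it matters, this is roughly how I plan to run inference once training works (an untested sketch; crf.decode returns one best label-id path per sequence, and I'm assuming id2label maps ints to tag strings):

text = "John lives in New York"
enc = tokenizer(text, return_tensors="pt")
model.eval()
with torch.no_grad():
    tags = model(enc["input_ids"], enc["attention_mask"])  # list of label-id paths
print([id2label[t] for t in tags[0]])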
