
I am training an MLM model using the PyTorch Trainer API. Here is my initial code:

import torch
import transformers as tr
from transformers import DataCollatorForWholeWordMask

# tokenizer, model and train_encodings are prepared earlier in the script
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)


class SEDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
        
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        return item

    def __len__(self):
        return len(self.encodings["attention_mask"])

train_data = SEDataset(train_encodings)
print("train_data prepared")


training_args = tr.TrainingArguments(
    output_dir='results_mlm_mmt2',
    logging_dir='logs_mlm_mmt2',        # directory for storing logs
    save_strategy="epoch",
    learning_rate=2e-5,
    logging_steps=40000,
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=32,
    prediction_loss_only=True,
    gradient_accumulation_steps=2,
    fp16=True,
)



trainer = tr.Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data
)

The above code works fine, but I want to add a few things:

  1. How can I add validation text data, and in what format? Do I also need to pass labels for the validation set?

  2. How can I have some MLM-related metrics printed every N steps?

MAC

1 Answer

  1. The Trainer class takes an eval_dataset argument, which lets you pass validation data. It should be a Dataset object, like your train_data object (see the sketch after this list).

  2. It also takes a compute_metrics argument, a function you can define to compute the metrics you want displayed during evaluation (see for example here).
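Below is a minimal sketch of both points. It assumes a val_encodings dict built the same way as train_encodings and uses masked-token accuracy as an example metric (both are my assumptions, not part of the question). Because DataCollatorForWholeWordMask creates the masked labels on the fly, no separate labels need to be passed for the validation set.

import numpy as np

val_data = SEDataset(val_encodings)   # same format as train_data; labels are created by the collator

def compute_metrics(eval_pred):
    # eval_pred.predictions: logits of shape (num_examples, seq_len, vocab_size)
    # eval_pred.label_ids: labels with -100 at every position that was not masked
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    mask = labels != -100                                  # score only the masked tokens
    accuracy = (predictions[mask] == labels[mask]).mean()
    return {"masked_token_accuracy": float(accuracy)}

training_args = tr.TrainingArguments(
    output_dir='results_mlm_mmt2',
    evaluation_strategy="steps",      # run evaluation every eval_steps
    eval_steps=40000,
    per_device_eval_batch_size=32,
    prediction_loss_only=False,       # must be False so compute_metrics receives the logits
    # ... keep the rest of your original arguments ...
)

trainer = tr.Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)

If the validation set is large, accumulating the full logits can use a lot of memory; the Trainer also accepts a preprocess_logits_for_metrics function that can reduce the logits to argmax ids before they are gathered.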

Clef.
  • MLM is different from classification, I guess? Will these metrics work? – MAC Aug 05 '22 at 08:10
  • You have to define the metrics that are relevant for you in this function. You can also take a look at the HF Evaluate lib, which might help. – Clef. Aug 05 '22 at 08:38
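
For completeness, a rough sketch of what the comment's suggestion could look like, using the Evaluate library's generic accuracy metric on the masked positions only (the metric choice is an assumption; the commenters did not specify one):

import numpy as np
import evaluate   # pip install evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    mask = labels != -100             # the collator marks non-masked positions with -100
    return accuracy.compute(predictions=predictions[mask], references=labels[mask])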