
I am training a text classification model on a large dataset using the BERT classifier (bert-base-uncased) from the simpletransformers library. By default, Simple Transformers reports mcc and eval_loss for evaluation, both during training and in the test (eval) phase. I was able to add further metrics such as accuracy and F1 for the test phase by passing extra metric functions to the eval_model method, but I don't know how to make Simple Transformers report these metrics during training as well. Is it possible to do the same thing with the train_model method?
Note that the evaluate_during_training option is set to True.
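For reference, here is a minimal sketch of how this option is usually passed to Simple Transformers through the `args` dictionary. The key names are meant to match `ModelArgs` fields, but the step count and output directory are illustrative assumptions, so check them against your installed version.

```python
# Illustrative training configuration for simpletransformers' ClassificationModel.
# Only the key names are meant to match the library; the values are assumptions.
model_args = {
    "evaluate_during_training": True,        # run evaluation while training
    "evaluate_during_training_steps": 2000,  # evaluate every 2000 steps (assumed value)
    "save_eval_checkpoints": True,           # save a checkpoint at each evaluation
    "output_dir": "outputs/",                # checkpoints (and their eval results) go here
}

# Usage (needs simpletransformers and a model download, so it is not run here):
# from simpletransformers.classification import ClassificationModel
# model = ClassificationModel("bert", "bert-base-uncased", args=model_args)
```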

During training it writes the mcc and eval_loss for each checkpoint (to eval_results.txt under the outputs directory), and I need the other metrics to be reported at each checkpoint as well.

result, model_outputs, wrong_predictions = model.eval_model(eval_df, f1=f1_multiclass, acc=accuracy_score)
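Each extra keyword argument acts as a metric: its name becomes the key in the results, and its value is a callable that receives the true labels and the predicted labels. The sketch below imitates that convention with a hypothetical `compute_extra_metrics` helper and a hand-written `accuracy` function, so it runs without Simple Transformers or scikit-learn. It approximates the library's behavior; it is not the library's actual code.

```python
from typing import Callable, Dict, Sequence


def accuracy(labels: Sequence[int], preds: Sequence[int]) -> float:
    """Fraction of predictions that match the true labels."""
    correct = sum(1 for y, p in zip(labels, preds) if y == p)
    return correct / len(labels)


def compute_extra_metrics(labels, preds, **metrics: Callable) -> Dict[str, float]:
    """Hypothetical helper: call metric(labels, preds) for every keyword argument."""
    return {name: fn(labels, preds) for name, fn in metrics.items()}


labels = [0, 1, 2, 1]
preds = [0, 1, 1, 1]
results = compute_extra_metrics(labels, preds, acc=accuracy)
print(results)  # {'acc': 0.75}
```

This is why any function with the `(labels, preds)` signature, such as scikit-learn's `accuracy_score` or a small wrapper around `f1_score`, can be passed directly.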

Thanks in advance

cheers

Firouziam

1 Answer


After searching the web without finding an answer, I started reading the source code, and it turns out to be much simpler than I expected. To report extra metrics during training, pass them to train_model exactly the way you pass them to eval_model: as keyword arguments whose values are metric functions. Here is sample code showing how to feed extra metrics to the Simple Transformers train_model and eval_model methods.

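from sklearn.metrics import (accuracy_score, cohen_kappa_score, f1_score,
                             precision_score, recall_score)

# Each extra metric is a function taking (labels, preds); 'weighted'
# averaging makes F1/precision/recall usable for multiclass labels.
def f1_multiclass(labels, preds):
    return f1_score(labels, preds, average='weighted')

def prec_multiclass(labels, preds):
    return precision_score(labels, preds, average='weighted')

def recall_multiclass(labels, preds):
    return recall_score(labels, preds, average='weighted')

# With evaluate_during_training=True, these metrics are computed on eval_df
# at every evaluation during training, alongside mcc and eval_loss.
model.train_model(train_df, eval_df=test_df,
                  f1=f1_multiclass,
                  acc=accuracy_score,
                  prec=prec_multiclass,
                  recall=recall_multiclass,
                  cohen=cohen_kappa_score)

# The same keyword arguments add the metrics to the final evaluation.
result, model_outputs, wrong_predictions = model.eval_model(test_df,
                                                            f1=f1_multiclass,
                                                            acc=accuracy_score,
                                                            prec=prec_multiclass,
                                                            recall=recall_multiclass,
                                                            cohen=cohen_kappa_score)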
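To compare the extra metrics across checkpoints afterwards, each checkpoint's `eval_results.txt` can be read back. This sketch assumes every line has the form `key = value` (check a file from your own `outputs` folder first); the `read_eval_results` name and the `checkpoint-*` directory pattern in the usage comment are hypothetical.

```python
from pathlib import Path
from typing import Dict, Union


def read_eval_results(path) -> Dict[str, Union[float, str]]:
    """Parse an eval_results.txt file, assumed to hold one 'key = value' per line."""
    results: Dict[str, Union[float, str]] = {}
    for line in Path(path).read_text().splitlines():
        if "=" not in line:
            continue  # skip blank or unexpected lines
        key, value = (part.strip() for part in line.split("=", 1))
        try:
            results[key] = float(value)
        except ValueError:
            results[key] = value  # keep non-numeric values as strings
    return results


# Hypothetical usage: print the metrics saved for every checkpoint.
# for f in sorted(Path("outputs").glob("checkpoint-*/eval_results.txt")):
#     print(f.parent.name, read_eval_results(f))
```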
Dharman
Firouziam