
I am working on a multiclass, highly imbalanced classification problem. I use a random forest as the base classifier.

I have to report the model's performance on the evaluation set using multiple metrics: precision, recall, confusion matrix, and ROC AUC.

Model training:

from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier()
rf.fit(train_X, train_y)

To obtain precision, recall, and the confusion matrix, I do:

from sklearn import metrics

pred = rf.predict(test_X)
# average= is required for multiclass targets; macro is one option
precision = metrics.precision_score(test_y, pred, average='macro')
recall = metrics.recall_score(test_y, pred, average='macro')
f1 = metrics.f1_score(test_y, pred, average='macro')
conf_matrix = metrics.confusion_matrix(test_y, pred)

Fine, but computing roc_auc requires the predicted class probabilities rather than the class labels. For that I additionally do:

y_prob = rf.predict_proba(test_X)
# multi_class= is required for multiclass targets
roc_auc = metrics.roc_auc_score(test_y, y_prob, multi_class='ovr')

But I'm worried that the output produced by rf.predict() may not be consistent with rf.predict_proba(), and therefore with the roc_auc score I'm reporting. I know that calling predict several times produces exactly the same result, but I'm concerned that predict followed by predict_proba might produce slightly different results, which would make it inappropriate to report roc_auc alongside the metrics above.

If that is the case, is there a way to control this, to make sure the class probabilities used by predict() to decide the predicted labels are exactly the same ones returned when I then call predict_proba()?
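
For instance, would it be equivalent to call predict_proba() once and derive the labels from that same array myself? A sketch of what I mean (taking np.argmax and mapping it through rf.classes_):

import numpy as np

# Compute the probabilities once...
y_prob = rf.predict_proba(test_X)
# ...and derive the labels from the same array
pred = rf.classes_[np.argmax(y_prob, axis=1)]

That way every metric above would be computed from a single set of probabilities.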

arilwan
As a general comment, for highly imbalanced classification, accuracy, precision, recall, and F1 score are poor metrics for evaluating your model. If you want to evaluate how well your model can distinguish between the various classes, focus on ROC AUC. If you're trying to optimize for a business problem and know the cost of mistakes (e.g. labeling A as B is 10x worse than labeling B as A), focus on [f_beta_score](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html). – Swier Mar 26 '21 at 11:42
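
For reference, a minimal sketch of the fbeta_score suggestion, assuming the test_y and pred variables from the question and macro averaging:

from sklearn.metrics import fbeta_score

# beta > 1 weights recall more heavily, beta < 1 weights precision;
# average='macro' treats all classes equally, as above
fb2 = fbeta_score(test_y, pred, beta=2.0, average='macro')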

1 Answer


predict_proba() and predict() are consistent with each other. In fact, predict uses predict_proba internally, as can be seen in the scikit-learn source code.
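
A quick empirical check makes this concrete (a sketch reusing the rf and test_X from the question): the labels returned by predict() are exactly the argmax of predict_proba() mapped through rf.classes_:

import numpy as np

proba = rf.predict_proba(test_X)
# predict() picks the class with the highest probability
labels_from_proba = rf.classes_[np.argmax(proba, axis=1)]

# Identical to calling predict() directly
assert (labels_from_proba == rf.predict(test_X)).all()

So the roc_auc you compute from predict_proba() is based on exactly the probabilities that produced the labels behind your precision/recall/f1 numbers.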

Swier