
How to get the ROC AUC score for multi-class classification in sklearn?

Binary case:

from sklearn.metrics import roc_auc_score

# this works
roc_auc_score([0,1,1], [1,1,1])

Multiclass case:

# this fails
from sklearn.metrics import roc_auc_score

ytest  = [0,1,2,3,2,2,1,0,1]
ypreds = [1,2,1,3,2,2,0,1,1]

roc_auc_score(ytest, ypreds,average='macro',multi_class='ovo')

# AxisError: axis 1 is out of bounds for array of dimension 1

I looked at the official documentation but could not solve the issue.

desertnaut
BhishanPoudel

1 Answer


For the multilabel case, roc_auc_score expects binary label indicators with shape (n_samples, n_classes). Converting your labels to that form is a way to get back to a one-vs-all (one-vs-rest) setting.

To do that easily, you can use label_binarize (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.label_binarize.html#sklearn.preprocessing.label_binarize).
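As a quick illustration of what label_binarize returns (the three-class input below is made up for the example), each label becomes one row with a single 1 in the column of its class, so the result has shape (n_samples, n_classes):

```python
from sklearn.preprocessing import label_binarize

# Illustrative labels only; any class labels listed in `classes` work the same way.
y = [0, 2, 1, 2]
Y = label_binarize(y, classes=[0, 1, 2])

print(Y.shape)  # (4, 3)
print(Y)
# [[1 0 0]
#  [0 0 1]
#  [0 1 0]
#  [0 0 1]]
```

One caveat: with only two classes, label_binarize returns a single column of shape (n_samples, 1) rather than two columns.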

For your code, it will be:

from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# You need the labels to binarize
labels = [0, 1, 2, 3]

ytest  = [0,1,2,3,2,2,1,0,1]

# Binarize ytest with shape (n_samples, n_classes)
ytest = label_binarize(ytest, classes=labels)

ypreds = [1,2,1,3,2,2,0,1,1]

# Binarize ypreds with shape (n_samples, n_classes)
ypreds = label_binarize(ypreds, classes=labels)

roc_auc_score(ytest, ypreds, average='macro', multi_class='ovo')

Here, ytest and ypreds become:

ytest
array([[1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 1, 0],
       [0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 0],
       [0, 1, 0, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0]])

ypreds
array([[0, 1, 0, 0],
       [0, 0, 1, 0],
       [0, 1, 0, 0],
       [0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 0],
       [1, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 1, 0, 0]])
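Because ytest is now a 2-D indicator matrix, roc_auc_score handles it as a multilabel problem: it computes one binary ROC AUC per column (that class versus the rest) and, with average='macro', takes their unweighted mean; on this code path the multi_class argument is not used. A minimal sketch that checks this by computing the per-class scores by hand on the same data:

```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

labels = [0, 1, 2, 3]
ytest = label_binarize([0, 1, 2, 3, 2, 2, 1, 0, 1], classes=labels)
ypreds = label_binarize([1, 2, 1, 3, 2, 2, 0, 1, 1], classes=labels)

# One binary ROC AUC per class (column): that class versus all the others.
per_class = [roc_auc_score(ytest[:, k], ypreds[:, k]) for k in range(len(labels))]

# Macro average: the unweighted mean of the per-class scores.
macro = roc_auc_score(ytest, ypreds, average="macro")

print(per_class)
print(macro, np.mean(per_class))
```

Since ypreds contains hard 0/1 predictions rather than scores, each per-class ROC curve has only one intermediate point, so the resulting AUC is a coarse summary.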
Melkozaur
I'm using Python 3, and I ran your code above and got the following error: TypeError: roc_auc_score() got an unexpected keyword argument 'multi_class'. So I updated to scikit-learn 0.23.2 (had 0.23.1). Thanks for the post. – spacedustpi Nov 03 '20 at 19:27
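The multi_class argument (available in roc_auc_score since scikit-learn 0.22) is used only when y_true holds plain class labels (1-D) and y_score is a 2-D array of shape (n_samples, n_classes) whose rows sum to 1, typically the output of a classifier's predict_proba. The call in the question fails because ypreds is 1-D. A minimal sketch of that usage, assuming a LogisticRegression fitted on the iris dataset, both chosen purely for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Illustrative data and model; any classifier with predict_proba works similarly.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Shape (n_samples, n_classes); each row is a probability distribution over classes.
proba = clf.predict_proba(X_test)

# y_test stays 1-D; the score matrix is 2-D, so the multiclass path is taken.
auc_ovr = roc_auc_score(y_test, proba, multi_class="ovr", average="macro")
auc_ovo = roc_auc_score(y_test, proba, multi_class="ovo", average="macro")
print(auc_ovr, auc_ovo)
```

Probability scores keep the ranking information that hard predictions throw away, which is what ROC AUC is designed to measure.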