
Here is a very small example using precision_recall_curve():

from sklearn.metrics import precision_recall_curve, precision_score, recall_score
y_true = [0, 1]
y_predict_proba = [0.25, 0.75]
precision, recall, thresholds = precision_recall_curve(y_true, y_predict_proba)
precision, recall

which results in:

(array([1., 1.]), array([1., 0.]))

The above does not match the "manual" calculation which follows.

There are three possible class vectors depending on the threshold: [0,0] (when the threshold is > 0.75), [0,1] (when the threshold is between 0.25 and 0.75), and [1,1] (when the threshold is < 0.25). We have to discard [0,0] because it gives an undefined precision (division by zero). So, applying precision_score() and recall_score() to the other two:

y_predict_class = [0, 1]
precision_score(y_true, y_predict_class), recall_score(y_true, y_predict_class)

which gives:

(1.0, 1.0)

and

y_predict_class = [1, 1]
precision_score(y_true, y_predict_class), recall_score(y_true, y_predict_class)

which gives:

(0.5, 1.0)
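For comparison, the manual calculation can be written as a threshold sweep in plain Python. This is a sketch under the assumption that a sample counts as positive when its score is greater than or equal to the threshold; the helper name `manual_pr_points` is hypothetical. It reproduces both (precision, recall) pairs computed above, including the 0.5 precision that does not appear in the precision_recall_curve() output.

```python
def manual_pr_points(y_true, scores):
    """Return a (threshold, precision, recall) triple for each distinct score.

    Assumption: a sample is predicted positive when score >= threshold.
    Assumes y_true contains at least one positive (otherwise recall is 0/0).
    """
    n_pos = sum(y_true)
    points = []
    for t in sorted(set(scores)):
        y_pred = [1 if s >= t else 0 for s in scores]
        tp = sum(1 for yt, yp in zip(y_true, y_pred) if yt == 1 and yp == 1)
        n_pred_pos = sum(y_pred)
        if n_pred_pos == 0:
            continue  # nothing predicted positive: precision is 0/0, skip it
        points.append((t, tp / n_pred_pos, tp / n_pos))
    return points


print(manual_pr_points([0, 1], [0.25, 0.75]))
```

For the example above it returns [(0.25, 0.5, 1.0), (0.75, 1.0, 1.0)], one (threshold, precision, recall) triple per distinct score.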

This does not seem to match the output of precision_recall_curve() (which, for example, did not produce a precision value of 0.5).

Am I missing something?

dabru

1 Answer


I know I'm late, but I had the same doubt and eventually worked it out. The key point is that precision_recall_curve() stops outputting precision and recall values once full recall is reached for the first time; moreover, it appends a 0 to the recall array and a 1 to the precision array so that the curve starts on the y-axis.
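Below is a minimal NumPy sketch of the behaviour described above, reconstructed from this description rather than copied from sklearn. The name `pr_curve_sketch` is hypothetical, and the exact behaviour of precision_recall_curve() may differ between scikit-learn versions. Each distinct score is used as a threshold, points after the first one that reaches full recall are dropped, the remaining points are reversed, and the (precision=1, recall=0) point is appended.

```python
import numpy as np


def pr_curve_sketch(y_true, scores):
    """Hypothetical re-implementation of the truncate-at-full-recall behaviour.

    Returns (precision, recall, thresholds) with recall decreasing and the
    arrays ending in the appended point precision=1, recall=0.
    Assumes y_true contains at least one positive.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    # Sort by decreasing score; each distinct score is a candidate threshold.
    order = np.argsort(scores)[::-1]
    y_sorted, s_sorted = y_true[order], scores[order]
    # Index of the last sample for each distinct score value.
    distinct = np.r_[np.where(np.diff(s_sorted))[0], y_sorted.size - 1]
    tps = np.cumsum(y_sorted)[distinct]  # true positives at each threshold
    fps = 1 + distinct - tps             # false positives at each threshold
    thresholds = s_sorted[distinct]
    precision = tps / (tps + fps)
    recall = tps / tps[-1]
    # Keep points only up to the first threshold that reaches full recall,
    # and reverse them so that recall is decreasing.
    last_ind = np.searchsorted(tps, tps[-1])
    sl = slice(last_ind, None, -1)
    return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl]
```

For example, with y_true = [0, 0, 1, 1] and scores [0.1, 0.4, 0.35, 0.8] the sketch returns precision of about [0.667, 0.5, 1, 1], recall [1, 0.5, 0.5, 0] and thresholds [0.35, 0.4, 0.8]: the threshold 0.1, which also gives full recall, is discarded.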

In your specific example, you effectively get two intermediate arrays like these (they are in the reverse order of the final output because of how sklearn implements it):

precision, recall
(array([1., 0.5]), array([1., 1.]))

Then the values corresponding to the second occurrence of full recall (precision 0.5, recall 1.0) are dropped, and 1 and 0 (for precision and recall, respectively) are appended as described above:

precision, recall
(array([1., 1.]), array([1., 0.]))
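The same two steps can be traced on the intermediate arrays of your example with a few NumPy operations. This is an illustrative sketch of the transformation described here, not sklearn's actual code.

```python
import numpy as np

# Intermediate arrays for the example, one entry per threshold,
# ordered by decreasing threshold (0.75, then 0.25).
precision = np.array([1.0, 0.5])
recall = np.array([1.0, 1.0])

# Step 1: keep only the points up to the first time recall reaches its maximum (1.0).
first_full = int(np.argmax(recall == recall[-1]))
precision = precision[: first_full + 1]
recall = recall[: first_full + 1]

# Step 2: reverse so recall is decreasing, then append the (precision=1, recall=0) point.
precision = np.append(precision[::-1], 1.0)
recall = np.append(recall[::-1], 0.0)

print(precision, recall)
```

This prints [1. 1.] [1. 0.], the same values precision_recall_curve() returned in the question.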

I have tried to explain it in full detail here; another useful link is this one.

amiola