0

I'm attempting to make an averaged precision-recall plot, based on the approach in https://stats.stackexchange.com/questions/186337/average-roc-for-repeated-10-fold-cross-validation-with-probability-estimates. Using their method, I've come up with the following minimal working example (the original is much larger, of course):

import matplotlib.pyplot as plt
import numpy as np

# Common grid of 101 evenly spaced x positions from 0 to 1; each curve below is
# interpolated onto this grid so the curves can be combined point by point.
base_x = np.linspace(0, 1, 101)

plt.figure()
# Collects one interpolated curve per data set.
ys = []
x = [1,0.96,0.96,0.96,0.96,0.96,0.96,0.96,0.96,0.96,0.92,0.92,0.92,0.88,0.88,0.88,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.8,0.8,0.76,0.76,0.76,0.76,0.76,0.76,0.72,0.68,0.68,0.64,0.64,0.64,0.64,0.64,0.64,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.56,0.56,0.56,0.52,0.52,0.52,0.52,0.52,0.52,0.52,0.52,0.52,0.48,0.48,0.48,0.48,0.48,0.48,0.48,0.44,0.4,0.36,0.28,0.28,0.28,0.28,0.28,0.28,0.24,0.24,0.2,0.16,0.16,0.12,0.12,0.08,0.04,0]
y = [0.115207373271889,0.111111111111111,0.111627906976744,0.11214953271028,0.112676056338028,0.113207547169811,0.113744075829384,0.114285714285714,0.114832535885167,0.115384615384615,0.111111111111111,0.111650485436893,0.11219512195122,0.107843137254902,0.108374384236453,0.108910891089109,0.104477611940299,0.105,0.105527638190955,0.106060606060606,0.106598984771574,0.107142857142857,0.128834355828221,0.12962962962963,0.130434782608696,0.13125,0.132075471698113,0.132911392405063,0.133757961783439,0.134615384615385,0.135483870967742,0.136363636363636,0.137254901960784,0.138157894736842,0.139072847682119,0.14,0.140939597315436,0.141891891891892,0.142857142857143,0.143835616438356,0.154411764705882,0.155555555555556,0.156716417910448,0.157894736842105,0.162790697674419,0.1640625,0.165354330708661,0.166666666666667,0.184210526315789,0.185840707964602,0.1875,0.19047619047619,0.192307692307692,0.202127659574468,0.204301075268817,0.206521739130435,0.208791208791209,0.211111111111111,0.213483146067416,0.204545454545455,0.195402298850575,0.197674418604651,0.188235294117647,0.19047619047619,0.2,0.20253164556962,0.205128205128205,0.207792207792208,0.197368421052632,0.217391304347826,0.220588235294118,0.223880597014925,0.227272727272727,0.245901639344262,0.25,0.254237288135593,0.258620689655172,0.272727272727273,0.288461538461538,0.294117647058824,0.3,0.306122448979592,0.3125,0.319148936170213,0.304347826086957,0.311111111111111,0.318181818181818,0.302325581395349,0.30952380952381,0.317073170731707,0.325,0.333333333333333,0.342105263157895,0.351351351351351,0.393939393939394,0.40625,0.4,0.413793103448276,0.428571428571429,0.444444444444444,0.461538461538462,0.521739130434783,0.545454545454545,0.523809523809524,0.526315789473684,0.5,0.4375,0.466666666666667,0.5,0.538461538461538,0.583333333333333,0.636363636363636,0.6,0.666666666666667,0.625,0.571428571428571,0.666666666666667,0.75,1,1,1,1]
# Draw the raw curve faintly.
plt.plot(x, y, 'blue', alpha=0.15)
# Resample onto base_x. np.interp takes (x, xp, fp); this call passes y as xp
# and x as fp.
# NOTE(review): np.interp expects xp to be increasing — confirm that holds here.
y = np.interp(base_x, y, x)
ys.append(y)
# x coordinates of the second curve, listed from 1 down to 0 (presumably recall
# values as returned by precision_recall_curve — TODO confirm).
# Rows are grouped by runs of equal values.
x = [
    1,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.851851851851852, 0.851851851851852, 0.851851851851852, 0.851851851851852,
    0.851851851851852, 0.851851851851852, 0.851851851851852,
    0.814814814814815, 0.814814814814815, 0.814814814814815,
    0.777777777777778, 0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.777777777777778, 0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.777777777777778, 0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.740740740740741, 0.740740740740741,
    0.703703703703704, 0.703703703703704,
    0.666666666666667, 0.666666666666667, 0.666666666666667, 0.666666666666667,
    0.62962962962963, 0.62962962962963,
    0.592592592592593, 0.592592592592593, 0.592592592592593, 0.592592592592593,
    0.592592592592593, 0.592592592592593, 0.592592592592593, 0.592592592592593,
    0.592592592592593, 0.592592592592593, 0.592592592592593,
    0.555555555555556,
    0.518518518518518,
    0.481481481481481, 0.481481481481481, 0.481481481481481, 0.481481481481481,
    0.481481481481481,
    0.444444444444444, 0.444444444444444, 0.444444444444444, 0.444444444444444,
    0.444444444444444, 0.444444444444444, 0.444444444444444,
    0.407407407407407,
    0.37037037037037, 0.37037037037037, 0.37037037037037, 0.37037037037037,
    0.37037037037037, 0.37037037037037, 0.37037037037037, 0.37037037037037,
    0.37037037037037,
    0.333333333333333,
    0.296296296296296, 0.296296296296296,
    0.259259259259259, 0.259259259259259, 0.259259259259259, 0.259259259259259,
    0.259259259259259,
    0.222222222222222,
    0.185185185185185, 0.185185185185185, 0.185185185185185,
    0.148148148148148, 0.148148148148148,
    0.111111111111111,
    0.0740740740740741,
    0.037037037037037,
    0, 0,
]
y = [0.164634146341463,0.159509202453988,0.160493827160494,0.161490683229814,0.1625,0.163522012578616,0.164556962025316,0.165605095541401,0.166666666666667,0.167741935483871,0.168831168831169,0.169934640522876,0.171052631578947,0.173333333333333,0.174496644295302,0.175675675675676,0.17687074829932,0.178082191780822,0.179310344827586,0.180555555555556,0.183098591549296,0.184397163120567,0.185714285714286,0.18705035971223,0.188405797101449,0.182481751824818,0.183823529411765,0.185185185185185,0.186567164179104,0.18796992481203,0.189393939393939,0.190839694656489,0.192307692307692,0.193798449612403,0.1953125,0.196850393700787,0.198412698412698,0.2,0.201612903225806,0.203252032520325,0.204918032786885,0.206611570247934,0.208333333333333,0.210084033613445,0.211864406779661,0.213675213675214,0.21551724137931,0.217391304347826,0.219298245614035,0.221238938053097,0.214285714285714,0.216216216216216,0.218181818181818,0.220183486238532,0.222222222222222,0.224299065420561,0.226415094339623,0.228571428571429,0.230769230769231,0.233009708737864,0.235294117647059,0.237623762376238,0.24,0.242424242424242,0.244897959183673,0.247422680412371,0.25,0.252631578947368,0.25531914893617,0.258064516129032,0.260869565217391,0.263736263736264,0.266666666666667,0.269662921348315,0.261363636363636,0.264367816091954,0.267441860465116,0.270588235294118,0.273809523809524,0.27710843373494,0.280487804878049,0.271604938271605,0.275,0.278481012658228,0.269230769230769,0.272727272727273,0.276315789473684,0.28,0.283783783783784,0.287671232876712,0.291666666666667,0.295774647887324,0.3,0.304347826086957,0.308823529411765,0.313432835820896,0.318181818181818,0.323076923076923,0.328125,0.317460317460317,0.32258064516129,0.311475409836066,0.316666666666667,0.305084745762712,0.310344827586207,0.315789473684211,0.321428571428571,0.309090909090909,0.314814814814815,0.30188679245283,0.307692307692308,0.313725490196078,0.32,0.326530612244898,0.333333333333333,0.340425531914894,0.347826086956522,0.355555555555556
,0.363636363636364,0.372093023255814,0.357142857142857,0.341463414634146,0.325,0.333333333333333,0.342105263157895,0.351351351351351,0.361111111111111,0.342857142857143,0.352941176470588,0.363636363636364,0.375,0.387096774193548,0.4,0.413793103448276,0.392857142857143,0.37037037037037,0.384615384615385,0.4,0.416666666666667,0.434782608695652,0.454545454545455,0.476190476190476,0.5,0.526315789473684,0.5,0.470588235294118,0.5,0.466666666666667,0.5,0.538461538461538,0.583333333333333,0.636363636363636,0.6,0.555555555555556,0.625,0.714285714285714,0.666666666666667,0.8,0.75,0.666666666666667,0.5,0,1]
# Draw the raw curve faintly.
plt.plot(x, y, 'blue', alpha=0.15)
# Resample onto base_x. np.interp takes (x, xp, fp); this call passes y as xp
# and x as fp.
# NOTE(review): np.interp expects xp to be increasing — confirm that holds here.
y = np.interp(base_x, y, x)
ys.append(y)

# Stack the interpolated curves into one array (one row per curve).
ys = np.array(ys)
# Point-wise mean and standard deviation across curves at each base_x position.
std = ys.std(axis=0)
mean_ys = ys.mean(axis=0)
# Band edges: the upper edge is capped at 1, the lower edge is left uncapped.
ys_lower = mean_ys - std
ys_upper = np.minimum(1, mean_ys + std)
plt.fill_between(base_x, ys_lower, ys_upper, alpha=0.3, color='blue')

# Dashed red reference line from (0, 1) to (1, 0).
plt.plot([0, 1], [1, 0], "r--")
plt.savefig('images/average_pre_roc.svg', pad_inches=0.1, bbox_inches='tight')

which makes the following plot

enter image description here

but the problem is that the plot clearly isn't showing the means correctly. I've tried swapping x and y, but I'm at a loss as to why the means are showing up in the wrong place.

I've looked at "How to Plot PR-Curve Over 10 folds of Cross Validation in Scikit-Learn" and "Precision Recall curve with n-fold cross validation showing standard deviation", but I don't see how they apply in my case.

How can I plot the means/std correctly in the plot?

con
  • 5,767
  • 8
  • 33
  • 62

1 Answer

1

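roc_curve from sklearn returns its ROC points in ascending order, but precision_recall_curve returns the points of a precision-recall curve in descending order. np.interp(x, xp, fp) expects the sample points xp to be increasing, so the lists have to be reversed before interpolating, and the call has to be np.interp(base_x, x, y) rather than np.interp(base_x, y, x). Reversing x and y thus: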

x.reverse()
y.reverse()

fixed the problem.

The entire script is as follows:

import matplotlib.pyplot as plt
import numpy as np

# Common grid of 101 evenly spaced x positions from 0 to 1; each curve below is
# interpolated onto this grid so the curves can be combined point by point.
base_x = np.linspace(0, 1, 101)

plt.figure()
# Collects one interpolated curve per data set.
ys = []
x = [1,0.96,0.96,0.96,0.96,0.96,0.96,0.96,0.96,0.96,0.92,0.92,0.92,0.88,0.88,0.88,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.84,0.8,0.8,0.76,0.76,0.76,0.76,0.76,0.76,0.72,0.68,0.68,0.64,0.64,0.64,0.64,0.64,0.64,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.6,0.56,0.56,0.56,0.52,0.52,0.52,0.52,0.52,0.52,0.52,0.52,0.52,0.48,0.48,0.48,0.48,0.48,0.48,0.48,0.44,0.4,0.36,0.28,0.28,0.28,0.28,0.28,0.28,0.24,0.24,0.2,0.16,0.16,0.12,0.12,0.08,0.04,0]
y = [0.115207373271889,0.111111111111111,0.111627906976744,0.11214953271028,0.112676056338028,0.113207547169811,0.113744075829384,0.114285714285714,0.114832535885167,0.115384615384615,0.111111111111111,0.111650485436893,0.11219512195122,0.107843137254902,0.108374384236453,0.108910891089109,0.104477611940299,0.105,0.105527638190955,0.106060606060606,0.106598984771574,0.107142857142857,0.128834355828221,0.12962962962963,0.130434782608696,0.13125,0.132075471698113,0.132911392405063,0.133757961783439,0.134615384615385,0.135483870967742,0.136363636363636,0.137254901960784,0.138157894736842,0.139072847682119,0.14,0.140939597315436,0.141891891891892,0.142857142857143,0.143835616438356,0.154411764705882,0.155555555555556,0.156716417910448,0.157894736842105,0.162790697674419,0.1640625,0.165354330708661,0.166666666666667,0.184210526315789,0.185840707964602,0.1875,0.19047619047619,0.192307692307692,0.202127659574468,0.204301075268817,0.206521739130435,0.208791208791209,0.211111111111111,0.213483146067416,0.204545454545455,0.195402298850575,0.197674418604651,0.188235294117647,0.19047619047619,0.2,0.20253164556962,0.205128205128205,0.207792207792208,0.197368421052632,0.217391304347826,0.220588235294118,0.223880597014925,0.227272727272727,0.245901639344262,0.25,0.254237288135593,0.258620689655172,0.272727272727273,0.288461538461538,0.294117647058824,0.3,0.306122448979592,0.3125,0.319148936170213,0.304347826086957,0.311111111111111,0.318181818181818,0.302325581395349,0.30952380952381,0.317073170731707,0.325,0.333333333333333,0.342105263157895,0.351351351351351,0.393939393939394,0.40625,0.4,0.413793103448276,0.428571428571429,0.444444444444444,0.461538461538462,0.521739130434783,0.545454545454545,0.523809523809524,0.526315789473684,0.5,0.4375,0.466666666666667,0.5,0.538461538461538,0.583333333333333,0.636363636363636,0.6,0.666666666666667,0.625,0.571428571428571,0.666666666666667,0.75,1,1,1,1]
# The lists above run from x = 1 down to x = 0; reverse both in place so that x
# is non-decreasing, since np.interp expects increasing sample points.
x.reverse()
y.reverse()
# Draw the raw curve faintly.
plt.plot(x, y, 'blue', alpha=0.15)
# Resample y onto the common base_x grid (np.interp(x, xp, fp): xp=x, fp=y).
y = np.interp(base_x, x, y)
ys.append(y)
# x coordinates of the second curve, listed from 1 down to 0 (presumably recall
# values as returned by precision_recall_curve — TODO confirm).
# Rows are grouped by runs of equal values.
x = [
    1,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.962962962962963, 0.962962962962963, 0.962962962962963, 0.962962962962963,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926, 0.925925925925926, 0.925925925925926, 0.925925925925926,
    0.925925925925926,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.888888888888889, 0.888888888888889, 0.888888888888889, 0.888888888888889,
    0.851851851851852, 0.851851851851852, 0.851851851851852, 0.851851851851852,
    0.851851851851852, 0.851851851851852, 0.851851851851852,
    0.814814814814815, 0.814814814814815, 0.814814814814815,
    0.777777777777778, 0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.777777777777778, 0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.777777777777778, 0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.777777777777778, 0.777777777777778, 0.777777777777778,
    0.740740740740741, 0.740740740740741,
    0.703703703703704, 0.703703703703704,
    0.666666666666667, 0.666666666666667, 0.666666666666667, 0.666666666666667,
    0.62962962962963, 0.62962962962963,
    0.592592592592593, 0.592592592592593, 0.592592592592593, 0.592592592592593,
    0.592592592592593, 0.592592592592593, 0.592592592592593, 0.592592592592593,
    0.592592592592593, 0.592592592592593, 0.592592592592593,
    0.555555555555556,
    0.518518518518518,
    0.481481481481481, 0.481481481481481, 0.481481481481481, 0.481481481481481,
    0.481481481481481,
    0.444444444444444, 0.444444444444444, 0.444444444444444, 0.444444444444444,
    0.444444444444444, 0.444444444444444, 0.444444444444444,
    0.407407407407407,
    0.37037037037037, 0.37037037037037, 0.37037037037037, 0.37037037037037,
    0.37037037037037, 0.37037037037037, 0.37037037037037, 0.37037037037037,
    0.37037037037037,
    0.333333333333333,
    0.296296296296296, 0.296296296296296,
    0.259259259259259, 0.259259259259259, 0.259259259259259, 0.259259259259259,
    0.259259259259259,
    0.222222222222222,
    0.185185185185185, 0.185185185185185, 0.185185185185185,
    0.148148148148148, 0.148148148148148,
    0.111111111111111,
    0.0740740740740741,
    0.037037037037037,
    0, 0,
]
y = [0.164634146341463,0.159509202453988,0.160493827160494,0.161490683229814,0.1625,0.163522012578616,0.164556962025316,0.165605095541401,0.166666666666667,0.167741935483871,0.168831168831169,0.169934640522876,0.171052631578947,0.173333333333333,0.174496644295302,0.175675675675676,0.17687074829932,0.178082191780822,0.179310344827586,0.180555555555556,0.183098591549296,0.184397163120567,0.185714285714286,0.18705035971223,0.188405797101449,0.182481751824818,0.183823529411765,0.185185185185185,0.186567164179104,0.18796992481203,0.189393939393939,0.190839694656489,0.192307692307692,0.193798449612403,0.1953125,0.196850393700787,0.198412698412698,0.2,0.201612903225806,0.203252032520325,0.204918032786885,0.206611570247934,0.208333333333333,0.210084033613445,0.211864406779661,0.213675213675214,0.21551724137931,0.217391304347826,0.219298245614035,0.221238938053097,0.214285714285714,0.216216216216216,0.218181818181818,0.220183486238532,0.222222222222222,0.224299065420561,0.226415094339623,0.228571428571429,0.230769230769231,0.233009708737864,0.235294117647059,0.237623762376238,0.24,0.242424242424242,0.244897959183673,0.247422680412371,0.25,0.252631578947368,0.25531914893617,0.258064516129032,0.260869565217391,0.263736263736264,0.266666666666667,0.269662921348315,0.261363636363636,0.264367816091954,0.267441860465116,0.270588235294118,0.273809523809524,0.27710843373494,0.280487804878049,0.271604938271605,0.275,0.278481012658228,0.269230769230769,0.272727272727273,0.276315789473684,0.28,0.283783783783784,0.287671232876712,0.291666666666667,0.295774647887324,0.3,0.304347826086957,0.308823529411765,0.313432835820896,0.318181818181818,0.323076923076923,0.328125,0.317460317460317,0.32258064516129,0.311475409836066,0.316666666666667,0.305084745762712,0.310344827586207,0.315789473684211,0.321428571428571,0.309090909090909,0.314814814814815,0.30188679245283,0.307692307692308,0.313725490196078,0.32,0.326530612244898,0.333333333333333,0.340425531914894,0.347826086956522,0.355555555555556
,0.363636363636364,0.372093023255814,0.357142857142857,0.341463414634146,0.325,0.333333333333333,0.342105263157895,0.351351351351351,0.361111111111111,0.342857142857143,0.352941176470588,0.363636363636364,0.375,0.387096774193548,0.4,0.413793103448276,0.392857142857143,0.37037037037037,0.384615384615385,0.4,0.416666666666667,0.434782608695652,0.454545454545455,0.476190476190476,0.5,0.526315789473684,0.5,0.470588235294118,0.5,0.466666666666667,0.5,0.538461538461538,0.583333333333333,0.636363636363636,0.6,0.555555555555556,0.625,0.714285714285714,0.666666666666667,0.8,0.75,0.666666666666667,0.5,0,1]
# The lists above run from x = 1 down to x = 0; reverse both in place so that x
# is non-decreasing, since np.interp expects increasing sample points.
x.reverse()
y.reverse()
# Draw the raw curve faintly.
plt.plot(x, y, 'blue', alpha=0.15)
# Resample y onto the common base_x grid (np.interp(x, xp, fp): xp=x, fp=y).
y = np.interp(base_x, x, y)
ys.append(y)

# Stack the interpolated curves into one array (one row per curve).
ys = np.array(ys)
# Point-wise mean and standard deviation across curves at each base_x position.
std = ys.std(axis=0)
mean_ys = ys.mean(axis=0)
# Band edges: the upper edge is capped at 1, the lower edge is left uncapped.
ys_lower = mean_ys - std
ys_upper = np.minimum(1, mean_ys + std)
plt.fill_between(base_x, ys_lower, ys_upper, alpha=0.3, color='blue')

# Dashed red reference line from (0, 1) to (1, 0).
plt.plot([0, 1], [1, 0], "r--")
plt.savefig('images/average_pre_roc.png', pad_inches=0.1, bbox_inches='tight')

which produces the following image:

enter image description here

con
  • 5,767
  • 8
  • 33
  • 62