I want to generate a Precision-Recall curve with 5-fold cross-validation showing standard deviation as in the example ROC curve code here.
The code below (adapted from How to Plot PR-Curve Over 10 folds of Cross Validation in Scikit-Learn) plots a PR curve for each fold of cross-validation along with the mean PR curve. I also want to shade the region one standard deviation above and below the mean PR curve in grey, but that step gives the following error (details in the link below the code):
ValueError: operands could not be broadcast together with shapes (91,) (78,)
import matplotlib.pyplot as plt
import numpy
from sklearn.datasets import make_blobs
from sklearn.metrics import precision_recall_curve, auc
from sklearn.model_selection import KFold
from sklearn.svm import SVC
X, y = make_blobs(n_samples=500, n_features=2, centers=2, cluster_std=10.0,
                  random_state=10)
k_fold = KFold(n_splits=5, shuffle=True, random_state=10)
predictor = SVC(kernel='linear', C=1.0, probability=True, random_state=10)
y_real = []
y_proba = []
precisions, recalls = [], []
for i, (train_index, test_index) in enumerate(k_fold.split(X)):
    Xtrain, Xtest = X[train_index], X[test_index]
    ytrain, ytest = y[train_index], y[test_index]
    predictor.fit(Xtrain, ytrain)
    pred_proba = predictor.predict_proba(Xtest)
    precision, recall, _ = precision_recall_curve(ytest, pred_proba[:, 1])
    lab = 'Fold %d AUC=%.4f' % (i+1, auc(recall, precision))
    plt.plot(recall, precision, alpha=0.3, label=lab)
    y_real.append(ytest)
    y_proba.append(pred_proba[:, 1])
    precisions.append(precision)
    recalls.append(recall)
y_real = numpy.concatenate(y_real)
y_proba = numpy.concatenate(y_proba)
precision, recall, _ = precision_recall_curve(y_real, y_proba)
lab = 'Overall AUC=%.4f' % (auc(recall, precision))
plt.plot(recall, precision, lw=2,color='red', label=lab)
# The next line is where the ValueError is raised: the per-fold precision arrays have different lengths
std_precision = numpy.std(precisions, axis=0)
upper_precision = numpy.minimum(precision + std_precision, 1)
lower_precision = numpy.maximum(precision - std_precision, 0)
plt.fill_between(recall, upper_precision, lower_precision, alpha=0.5, linewidth=0, color='grey')
Error reported and Plot generated
Can you please suggest how I can extend the code above to also show one standard deviation around the mean PR curve?
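For reference, here is a minimal sketch of one possible way around the shape mismatch. The error seems to come from `precision_recall_curve` returning a different number of points for each fold, so the per-fold arrays cannot be stacked directly. The sketch interpolates every fold onto a shared recall grid before taking the mean and standard deviation. The helper name `pr_band`, the parameter `n_points`, and the toy arrays are all made up for illustration. Linear interpolation of PR curves is a simplification and can look optimistic, so treat the resulting band as a rough visual aid.

```python
import numpy as np


def pr_band(recalls, precisions, n_points=101):
    """Interpolate per-fold PR curves onto a shared recall grid (hypothetical helper).

    recalls / precisions: lists of 1-D arrays, one pair per fold, in the order
    returned by sklearn.metrics.precision_recall_curve (recall decreasing).
    """
    recall_grid = np.linspace(0.0, 1.0, n_points)
    interpolated = []
    for recall, precision in zip(recalls, precisions):
        # np.interp expects increasing x values, so reverse both arrays.
        # Assumption: repeated recall values are acceptable for a rough band.
        interpolated.append(np.interp(recall_grid, recall[::-1], precision[::-1]))
    interpolated = np.vstack(interpolated)  # shape: (n_folds, n_points)

    mean_precision = interpolated.mean(axis=0)
    std_precision = interpolated.std(axis=0)
    upper = np.minimum(mean_precision + std_precision, 1.0)
    lower = np.maximum(mean_precision - std_precision, 0.0)
    return recall_grid, mean_precision, lower, upper


# Toy curves of different lengths stand in for two folds.
recalls = [np.array([1.0, 0.5, 0.0]), np.array([1.0, 0.75, 0.5, 0.0])]
precisions = [np.array([0.5, 0.8, 1.0]), np.array([0.6, 0.7, 0.9, 1.0])]

recall_grid, mean_p, lower, upper = pr_band(recalls, precisions, n_points=5)
print(mean_p.shape, lower.shape, upper.shape)
```

With the `recalls` and `precisions` lists collected in the loop above, something like `plt.plot(recall_grid, mean_p)` followed by `plt.fill_between(recall_grid, lower, upper, color='grey', alpha=0.5, linewidth=0)` should draw the mean curve and the grey band.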