1

I am trying to write a logistic regression with L1 regularization and implement the roc curve.I use stochastic gradient ascend of log-likelihood as the cost function. But the roc curve I plot is weird.

My data set is from http://archive.ics.uci.edu/ml/datasets/HTRU2

Here is the roc curve I draw

my roc curve

And the roc curve plot by the scikit-learn function roc curve plot by roc_curve function

and here is my roc curve code

def plot_roc_curve(X_test, y_true, theta):
    scores = numpy.squeeze(sigmoid(X_test, theta))
    k = 10
    threshold = numpy.arange(0, 1, 1./k)
    threshold = numpy.append(threshold, [1.])
    FPR_x = numpy.zeros(len(threshold))
    TPR_y = numpy.zeros(len(threshold))
    for n in range(len(threshold)):
        y_pred, FPR, TPR = predict(X_test, y_true, scores, threshold[n])
        FPR_x[n] = FPR
        TPR_y[n] = TPR
        print(threshold[n])

    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.plot(FPR_x, TPR_y)
    plt.scatter(FPR_x, TPR_y, marker='o', color='blue')
    plt.show()

def predict(X, y_true, y_scores, threshold):
    n = X.shape[0]

    y_predict = numpy.zeros(n)
    P = sum(y_true)
    N = n - P

    for i, score in enumerate(y_scores):
        if score > threshold:
            y_predict[i] = 1
        else:
            y_predict[i] = 0
    FP, TP, FN, TN = confusion_matrix_values(y_predict, y_true)
    FPR = FP/N
    TPR = TP/P
    return y_predict, FPR, TPR

def confusion_matrix_values(y_predict, y_true):
    FP = .0
    TP = .0
    FN = .0
    TN = .0
    for i in range(len(y_predict)):
        if abs(y_predict[i]-y_true[i]) < 1e-6:# == True part
            if y_true[i] >= 1:
                TP+=1
            else:
                TN+=1
        else:# False part
            if y_true[i] <= 0:
                FP+=1
            else:
                FN+=1
    assert(FP+TP+FN+TN == len(y_predict))
    return FP, TP, FN, TN 

Below is the whole code:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import numpy
import pandas
import sklearn.metrics
import sklearn.model_selection
import sklearn.linear_model
import sklearn.preprocessing
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc


def load_train_test_data(train_ratio=.8):
    data = pandas.read_csv('./HTRU2/HTRU_2.csv', header=None)

    X = data.iloc[:,:8]
    X = numpy.concatenate((numpy.ones((len(X), 1)), X), axis=1)
    y = data.iloc[:,8]
    y = numpy.array(y)
    print('the persentage of 1:',sum(y)/len(y))
    return sklearn.model_selection.train_test_split(X, y, test_size = 1 - train_ratio, random_state=0)


def scale_features(X_train, X_test, low=0, upp=1):
    minmax_scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(low, upp)).fit(numpy.vstack((X_train, X_test)))
    X_train_scale = minmax_scaler.transform(X_train)
    X_test_scale = minmax_scaler.transform(X_test)
    return X_train_scale, X_test_scale

# logreg_sgd with L1 regularization
def logreg_sgd(X, y, alpha = .001, iters = 100000, eps=1e-4):

    n, d = X.shape
    theta = numpy.zeros((d, 1))
    k = 0
    lam = 0.001
    not_converge = True
    y_pdt = numpy.zeros(iters)
    y_true_iters = numpy.zeros(iters)
    y_delta = numpy.zeros(iters)

    for k in range(iters):
        if not not_converge:
            break
        i = k%n
        x = X[i, :]
        xT = numpy.transpose([x])
        y_hat = sigmoid(x, theta)
        beta = de_norm1(theta)

        func_g = (y_hat - y[i])*xT + lam*beta
        theta_k = theta.copy()
        theta = theta - alpha*func_g
        not_converge = True
        for delta in abs(theta-theta_k):
            if delta > eps:
                converge = False
                break

    return theta

def de_norm1(theta):
    d, _ = theta.shape
    beta = numpy.zeros((d, 1))

    for i in range(d):
        if theta[i,0] < 0:
            beta[i,0] = -1
        elif theta[i,0] > 0:
            beta[i,0] = 1
    return beta

def sigmoid(X, theta):
    z = numpy.dot(X, theta)
    value = 1.0/(1.0 + numpy.exp(-z))
    return value

def predict(X, y_true, y_scores, threshold):
    n = X.shape[0]

    y_predict = numpy.zeros(n)
    P = sum(y_true)
    N = n - P

    for i, score in enumerate(y_scores):
        if score > threshold:
            y_predict[i] = 1
        else:
            y_predict[i] = 0
    FP, TP, FN, TN = confusion_matrix_values(y_predict, y_true)
    FPR = FP/N
    TPR = TP/P
    return y_predict, FPR, TPR

def confusion_matrix_values(y_predict, y_true):
    FP = .0
    TP = .0
    FN = .0
    TN = .0
    for i in range(len(y_predict)):
        if abs(y_predict[i]-y_true[i]) < 1e-6:# == True part
            if y_true[i] >= 1:
                TP+=1
            else:
                TN+=1
        else:# False part
            if y_true[i] <= 0:
                FP+=1
            else:
                FN+=1
    assert(FP+TP+FN+TN == len(y_predict))
    return FP, TP, FN, TN 

# plot the ROC curve of your prediction
# x aixes: TPR = TP / ( TP + FN )
# y aixes: FPR = FP / ( FP + TN ) 
def plot_roc_curve(X_test, y_true, theta):
    scores = numpy.squeeze(sigmoid(X_test, theta))
    k = 10
    threshold = numpy.arange(0, 1, 1./k)
    threshold = numpy.append(threshold, [1.])
    FPR_x = numpy.zeros(len(threshold))
    TPR_y = numpy.zeros(len(threshold))
    for n in range(len(threshold)):
        y_pred, FPR, TPR = predict(X_test, y_true, scores, threshold[n])
        FPR_x[n] = FPR
        TPR_y[n] = TPR
        print(threshold[n])

    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.plot(FPR_x, TPR_y)
    plt.scatter(FPR_x, TPR_y, marker='o', color='blue')
    plt.show()

def main(argv):
    X_train, X_test, y_train, y_test = load_train_test_data(train_ratio=.8)
    X_train_scale, X_test_scale = scale_features(X_train, X_test, 0, 1)

    theta = logreg_sgd(X_train_scale, y_train)

    scores = sigmoid(X_test, theta)
    fpr, tpr, thresholds = roc_curve(y_test, scores)
    plt.plot(fpr, tpr)
    plt.show()
    plot_roc_curve(X_train, y_train, theta)


if __name__ == "__main__":
    main(sys.argv)

sample data:

140.5625,55.68378214,-0.234571412,-0.699648398,3.199832776,19.11042633,7.975531794,74.24222492,0
102.5078125,58.88243001,0.465318154,-0.515087909,1.677257525,14.86014572,10.57648674,127.3935796,0
103.015625,39.34164944,0.323328365,1.051164429,3.121237458,21.74466875,7.735822015,63.17190911,0
136.75,57.17844874,-0.068414638,-0.636238369,3.642976589,20.9592803,6.89649891,53.59366067,0
88.7265625,40.67222541,0.600866079,1.123491692,1.178929766,11.4687196,14.26957284,252.5673058,0
93.5703125,46.69811352,0.53190485,0.416721117,1.636287625,14.54507425,10.6217484,131.3940043,0
119.484375,48.76505927,0.03146022,-0.112167573,0.99916388,9.279612239,19.20623018,479.7565669,0
130.3828125,39.84405561,-0.158322759,0.389540448,1.220735786,14.37894124,13.53945602,198.2364565,0
107.25,52.62707834,0.452688025,0.170347382,2.331939799,14.48685311,9.001004441,107.9725056,0
107.2578125,39.49648839,0.465881961,1.162877124,4.079431438,24.98041798,7.397079948,57.78473789,0
142.078125,45.28807262,-0.320328426,0.283952506,5.376254181,29.00989748,6.076265849,37.83139335,0
133.2578125,44.05824378,-0.081059862,0.115361506,1.632107023,12.00780568,11.97206663,195.5434476,0
134.9609375,49.55432662,-0.135303833,-0.080469602,10.69648829,41.34204361,3.893934139,14.13120625,0
117.9453125,45.50657724,0.325437564,0.661459458,2.836120401,23.11834971,8.943211912,82.47559187,0
138.1796875,51.5244835,-0.031852329,0.046797173,6.330267559,31.57634673,5.155939859,26.14331017,0
114.3671875,51.94571552,-0.094498904,-0.287984087,2.738294314,17.19189079,9.050612454,96.61190318,0
109.640625,49.01765217,0.13763583,-0.256699775,1.508361204,12.07290134,13.36792556,223.4384192,0
100.8515625,51.74352161,0.393836792,-0.011240741,2.841137124,21.63577754,8.302241891,71.58436903,0
136.09375,51.69100464,-0.045908926,-0.271816393,9.342809365,38.09639955,4.345438138,18.67364854,0
99.3671875,41.57220208,1.547196967,4.154106043,27.55518395,61.71901588,2.20880796,3.662680136,1
100.890625,51.89039446,0.627486528,-0.026497802,3.883779264,23.04526673,6.953167635,52.27944038,0
105.4453125,41.13996851,0.142653801,0.320419676,3.551839465,20.75501684,7.739552295,68.51977061,0
95.8671875,42.05992212,0.326386917,0.803501794,1.83277592,12.24896949,11.249331,177.2307712,0
117.3671875,53.90861351,0.257953441,-0.405049077,6.018394649,24.76612335,4.807783224,25.52261561,0
106.6484375,56.36718209,0.378355072,-0.266371607,2.43645485,18.40537062,9.378659682,96.86022536,0
112.71875,50.3012701,0.279390953,-0.129010712,8.281772575,37.81001224,4.691826852,21.27620977,0
130.8515625,52.43285734,0.142596727,0.018885442,2.64632107,15.65443599,9.464164025,115.6731586,0
119.4375,52.87481531,-0.002549267,-0.460360287,2.365384615,16.49803188,9.008351898,94.75565692,0
123.2109375,51.07801208,0.179376819,-0.17728516,2.107023411,16.92177312,10.08033334,112.5585913,0
102.6171875,49.69235371,0.230438984,0.193325371,1.489130435,16.00441146,12.64653474,171.8329021,0
110.109375,41.31816988,0.094860398,0.68311261,1.010033445,13.02627521,14.66651082,231.2041363,0
99.9140625,43.91949797,0.475728501,0.781486196,0.619565217,9.440975862,20.1066391,475.680218,0
128.34375,52.17210664,-0.049280401,-0.208256987,2.173913043,12.9939472,9.965757364,141.5100843,0
142.0546875,53.87315957,-0.470772686,-0.125946417,4.423076923,27.08351266,6.681658306,45.94403008,0
121.1328125,47.6326062,0.177360308,0.024918111,2.151337793,20.55243738,9.920468181,99.74707919,0
102.328125,48.98040255,0.315729409,-0.202183315,1.898829431,13.83904002,11.61993869,172.1303732,0
147.8359375,53.62263651,-0.131079596,-0.288851172,2.692307692,17.08088101,8.849177975,92.20174502,0
108.0390625,34.91024257,0.321156562,1.821631493,3.899665552,23.72205203,7.506209958,60.88691267,0
107.875,37.33065932,0.49600476,1.481815856,1.173913043,12.01691346,14.53428973,252.6947381,0
118.84375,45.9319193,-0.109242666,0.137683548,2.33277592,14.71602871,9.634175054,118.6696797,0
138.4609375,48.91716569,-0.039591916,-0.176243068,2.443143813,18.3133067,8.672894053,83.06924213,0
116.203125,47.34586165,0.211946824,-0.022177703,3.606187291,18.94498977,7.035644684,59.23122572,0
120.5546875,45.54990543,0.282923998,0.419908714,1.358695652,13.07903424,13.31214143,212.5970294,1
121.8828125,53.04267461,0.200520721,-0.282219034,2.116220736,16.58087621,8.947602793,91.01176155,0
125.2109375,51.17519729,0.139851288,-0.385736754,1.147993311,12.41401211,14.06879728,228.1315536,0
107.90625,48.08414459,0.460846577,0.29651005,1.993311037,13.84106954,9.969395408,128.7447168,0
106.28125,43.02178545,0.408868006,1.032014666,1.610367893,17.25115554,12.11019331,152.0149562,0
106.3359375,45.05002035,0.418645099,0.603995884,1.200668896,12.38856143,13.30645184,209.41199,0
125.734375,52.65772207,0.026516673,-0.429632907,4.850334448,29.93619483,6.361837308,40.25501275,0
113.546875,49.50029346,0.130001201,-0.202839025,2.407190635,14.42205142,9.310343318,113.6874714,0
134.0390625,51.80045885,-0.195844789,-0.396816077,1.107859532,13.23858397,13.77580037,208.4202575,0
105.1171875,45.09202762,0.464847891,0.878058377,4.283444816,23.96731526,6.562543005,46.66728734,0
95.328125,44.66789069,0.386495074,0.755115427,2.694816054,17.9985973,9.094177089,97.80243629,0
119.3359375,47.506953,0.220316758,0.645717725,0.79264214,9.540907141,18.76653977,441.5133427,0
136.1875,51.95291588,-0.070431774,-0.482219687,0.849498328,9.677531027,18.73655411,431.3904454,0
112.859375,55.10625168,0.174766173,-0.404019163,3.032608696,19.69431374,7.266252257,58.03777067,0
108.625,52.74614915,0.453556415,0.069731528,2.304347826,16.18365586,9.780440566,114.9993838,0
113.953125,49.2214161,0.234723211,0.289792216,1.081103679,13.48209307,14.25608113,216.8362204,0
141.96875,50.47089779,0.244974491,-0.342664657,2.823578595,16.23818776,8.207743613,85.53258352,0
136.5,49.9327673,0.044623267,-0.374311456,1.555183946,12.81353792,13.31433912,214.813089,0
83.6796875,36.37928102,0.572531753,2.66461052,4.0409699,23.16912864,7.006681423,53.51400467,0
27.765625,28.66604164,5.770087392,37.4190088,73.11287625,62.07021971,1.268206006,1.082920221,1
135.859375,51.93727202,0.065768774,-0.366114187,20.77424749,52.77264803,2.730908619,6.607439551,0
112.09375,48.81156969,0.418565459,0.350156301,2.204013378,17.37868175,9.520551079,100.7875964,0
126.8671875,53.1293191,0.13633915,-0.588709439,1.149665552,13.96514443,13.23049959,186.2685104,0
117.5390625,47.73296528,0.173139263,-0.150653604,1.060200669,14.28934355,14.17637248,208.2780851,0
143.0859375,49.92197464,-0.157561213,-0.153332697,3.563545151,21.28808157,7.337117054,59.16844081,0
101.296875,39.43395574,0.390053688,1.551969375,4.925585284,26.32242163,6.086053659,39.11620774,0
119.8984375,53.82550508,0.143378486,-0.528427658,4.04180602,24.57913147,6.581293412,44.89951492,0
123.125,50.33124651,-0.087091427,0.087932382,1.280936455,10.68864639,14.63669101,288.668932,0
102.046875,48.79050551,0.45222638,0.272447732,2.37541806,13.9284014,9.127499454,116.0232222,0
119.4453125,53.14305702,0.012830273,-0.378955989,2.932274247,17.9297569,8.289888515,81.34651657,0
128.515625,54.94585181,-0.012552759,-0.658278628,2.891304348,17.75294666,8.913745414,94.08210337,0
128.15625,46.89690113,-0.179233074,-0.005819915,4.193979933,22.25815766,6.451755484,46.48663173,0
115.6171875,40.29037592,0.110702345,0.513224267,11.63963211,39.95655753,3.640288988,12.68457562,0
136.7421875,44.39123754,-0.22192524,0.908084632,2.105351171,14.49837742,10.13157115,128.3951486,0
135.265625,48.14390609,0.015920939,-0.15877212,8.539297659,31.13487695,4.082788387,17.27267344,0
113.9609375,52.24736871,0.127976811,-0.457499415,4.407190635,26.29776588,6.709564866,47.4057088,0
107.796875,45.6803362,0.655279783,0.954879021,1.7090301,15.1907807,11.52025038,150.3053634,0
124.5,57.35361802,-0.014849043,-0.550963937,4.783444816,27.50164045,6.090448645,37.81809112,0
119.296875,46.45417086,0.202629139,0.12837064,3.748327759,18.8510099,6.414682286,50.85055687,0
148.3828125,51.200757,-0.113195798,-0.50223559,1.408026756,12.08791939,12.5121354,201.1278905,0
109.4921875,53.2901838,0.2528458,-0.319022964,4.132943144,25.89210734,6.741542034,46.83080307,0
112.125,46.30840906,0.721646098,0.612454163,1.173076923,11.04918969,14.6307442,273.2509626,0
128.7734375,45.80669555,0.086169154,-0.031764808,2.66722408,15.93295829,8.75667197,95.36727143,0
140.265625,48.93721813,0.03252958,0.119064502,2.315217391,19.87317992,9.67260138,98.89698457,0
87.515625,51.76343189,1.070588903,0.74283956,15.67809365,50.90591579,3.141187931,8.440045483,0
132.140625,42.09582342,0.143191723,0.876730035,1.863712375,13.26595667,10.25798651,140.0407088,0
104.078125,45.24078107,0.532040422,0.743853067,1.43645485,15.41478275,11.89911604,150.9872549,0
122.6015625,53.79697654,-0.051964773,-0.379729027,2.636287625,15.17095406,9.519292364,117.7422254,0
114.28125,41.25396525,0.41182113,0.616996141,2.412207358,20.42794216,9.198391753,88.37057957,0
112.4375,38.2956733,0.501943444,1.07484029,2.81270903,18.13688307,7.859968426,71.29944944,0
23.625,29.94865398,5.688038235,35.98717152,146.5685619,82.39462399,-0.274901598,-1.121848281,1
94.5859375,35.77982308,1.187308683,3.68746932,6.071070234,29.76039993,5.318766827,28.69804799,1
137.2421875,46.45474042,0.045257133,-0.438857507,59.4958194,77.75535652,0.71974817,-1.183162032,0
123.53125,53.34878418,0.072077648,-0.071600995,0.781772575,10.57083301,17.11829958,339.6608262,0
70.0234375,35.28067478,1.157657193,4.546692371,3.003344482,19.57538355,7.954436097,71.96015886,0
129.375,44.56841651,0.049779493,0.506330188,3.60451505,21.13303805,7.181384025,56.85662961,0
97.140625,47.77089438,0.625218075,0.740796144,4.193143813,26.46526062,6.927045631,49.62852693,0
101.96875,46.31632702,0.439814307,0.294261355,1.748327759,16.4866229,10.8103928,127.7333664,0
Vivek Kumar
  • 35,217
  • 8
  • 109
  • 132
song
  • 31
  • 1
  • 4

1 Answers1

1

The threshold is the problem. When plotting roc curve, the threshold should be the scores which get from the logistic function. The code get the threshold by splitting several part in range 0 to 1 equally.

By the trick, we can sort the scores in decent order and change the FP and TP by the value of the target.

Revised code:

def plot_roc_curve(X_test, y_true, theta):
    scores = numpy.squeeze(sigmoid(X_test, theta))
    scores_decent_ord = numpy.argsort(scores)[::-1]
    scores = scores[scores_decent_ord]
    y_true = y_true[scores_decent_ord]

    FPR_x = numpy.zeros(len(y_true)+1)
    TPR_y = numpy.zeros(len(y_true)+1)
    P = sum(y_true)
    N = len(y_true) - P
    FP = .0
    TP = .0
    for i, target in enumerate(y_true):
        if target >= 1:
            TP+=1
        else:
            FP+=1
        FPR_x[i] = FP/N
        TPR_y[i] = TP/P

    FPR_x[-1] = .0
    TPR_y[-1] = .0
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.plot(FPR_x, TPR_y)
    for i in range(len(FPR_x)):
        print(FPR_x[i], ',', TPR_y[i])
    # plt.scatter(FPR_x, TPR_y, marker='o', color='blue')
    plt.show()
song
  • 31
  • 1
  • 4