
I'm trying to use the decision tree classifier to identify two classes (labeled 0 and 1) based on certain parameters. I train it on one dataset and then run it on a test dataset. When I calculate the probability for each data point in the test dataset, it returns only 0 or 1. I wonder what the problem is.

Here is the sample code:

from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=0)
trained = clf.fit(data, identifier)  # training data, where identifier is 0 or 1
predict = trained.predict(test_data)

The results from this are:

In [9]: predict

Out[9]: 
array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
       0, 0, 1, 0, 0, 0])

In [10]: trained.predict_proba(test_data)[:,1]

Out[10]: 
array([ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,
        0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,
        1.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,
        0.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.])

I would like to generate an ROC curve, which at this point gives just 3 data points for FPR/TPR.
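For reference, the ROC computation looks something like this (a minimal sketch; `y_test` stands in for the true test labels, which are not included below):

from sklearn.metrics import roc_curve

# class-1 probabilities, as in Out[10] above
proba = trained.predict_proba(test_data)[:, 1]

# y_test is assumed to hold the true 0/1 labels for test_data
fpr, tpr, thresholds = roc_curve(y_test, proba)
# with only the two score values 0 and 1, fpr/tpr contain just 3 points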

Here is the complete data set; the identifier is the last column ("Class") of the train data.

Train data:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma,Class
1.4304664,0.61,2.18,0.3819051,0.99992716,1.93,0
1.6969398,0.54,1.93,0.66479063,0.9999814,2.11,0
2.233997,1.02,3.18,0.55532146,0.9999979,2.07,0
2.230639,0.77,2.34,0.0012237767,1.0,1.81,0
1.7325432,0.71,2.27,0.34395835,1.0,1.9,0
1.8728518,0.8,2.14,0.4255796,1.0,1.96,0
1.9818852,0.7,2.18,-0.08978904,1.0,1.66,0
2.3864453,0.95,2.51,0.109010585,0.98401743,1.81,0
2.5911317,0.94,2.49,0.60381645,0.99991965,2.03,0
1.9564596,0.81,2.29,0.3843,0.9999495,2.08,0
2.1506176,0.93,2.62,0.28551856,0.9999999,1.91,0
1.9069784,0.62,1.76,0.041608978,1.0,1.86,0
1.6216202,0.77,2.11,-0.14271076,1.0,1.7,0
2.276335,0.68,2.14,0.40399882,1.0,2.06,0
2.2430172,1.0,2.94,0.61844856,1.0,2.12,0
1.0226197,0.66,2.07,-0.14886126,1.0,1.84,0
2.2564504,1.06,2.77,0.6974536,0.99844635,2.16,0
2.2819016,0.88,2.37,0.30696234,0.999996,1.86,0
1.4881139,0.7,2.09,0.40853307,1.0,1.82,0
2.4640048,0.9,2.39,0.35103577,1.0,2.02,0
2.656071,0.72,2.29,0.21568911,0.9999046,2.11,0
1.7204628,0.62,2.01,0.19794853,1.0,1.8,0
1.9134961,0.86,2.27,0.37281907,1.0,1.94,0
1.3061943,0.67,2.01,0.3463318,0.99999976,1.86,0
1.8845558,0.64,2.01,0.12364135,0.9999834,1.84,0
2.4409518,1.12,3.31,0.7502838,1.0,2.17,0
1.9501582,0.85,2.34,0.29961613,0.9999974,1.92,0
2.1314192,1.03,2.62,0.69623667,1.0,2.28,0
1.7345899,0.69,2.61,0.38524705,0.99999887,2.09,0
1.7095753,0.75,2.08,0.21696341,0.9999987,1.95,0
1.9115254,0.83,2.17,-0.046689913,1.0,1.85,0
1.565369,0.67,2.01,-0.04827315,0.9999915,1.79,0
2.2971635,0.59,2.1,0.35741857,1.0,2.0,0
3.042759,1.06,2.94,0.70878696,0.9999844,2.15,0
2.340724,0.96,2.74,0.42822766,0.99999416,1.97,0
1.8552977,0.74,2.09,0.07262661,1.0,1.69,0
2.0324602,0.66,2.05,-0.07643526,0.9999982,1.83,0
1.8508979,0.67,1.96,0.054557554,0.99997455,1.75,0
2.7983437,0.96,2.58,0.8554537,0.9999992,2.2,0
2.1728642,1.09,3.05,0.61488354,1.0,2.04,0
3.113785,0.66,1.85,0.48011553,0.99995273,1.95,0
3.0665417,0.78,2.19,0.27814054,1.0,1.86,0
2.0060341,0.83,2.39,0.20785762,0.9999502,1.85,0
2.1786506,0.57,2.0,0.33096096,1.0,1.91,0
1.823961,0.72,1.96,-0.103285044,1.0,1.6,0
1.612012,0.68,2.15,-0.3136376,0.65517294,1.52,0
2.1615896,0.87,2.4,0.47535577,1.0,2.04,0
2.3053634,1.06,2.92,0.67040676,0.9991328,2.15,0
1.7525402,0.73,2.12,0.25563625,0.9999979,1.92,0
2.7306526,0.91,2.35,0.68943393,-0.4308276,2.1,0
2.2549937,1.07,2.91,0.6077795,0.9999626,2.04,0
2.0924683,0.69,2.04,-0.068183094,0.3497915,1.77,0
2.210627,0.84,2.09,0.6309954,0.99999976,1.99,0
2.4609168,0.67,2.08,0.29552716,0.99964327,1.96,0
2.5169518,0.84,2.45,0.35437247,0.9999745,1.92,0
2.1841373,0.9,2.51,0.5617463,1.0,2.15,0
3.0673068,0.8,2.22,0.17641401,1.0,1.9,0
2.6202004,0.97,2.47,0.36663872,1.0,2.03,0
1.9694642,0.95,2.54,0.33140072,0.99998665,2.04,0
1.8766946,0.84,2.32,-0.024992371,0.99999803,1.94,0
2.9352057,1.2,2.96,0.6385377,0.9951195,2.18,0
1.4075257,0.86,2.27,0.046303034,0.9999998,1.81,0
1.8769667,0.6,2.0,0.08842805,0.15410244,1.83,0
1.2585826,0.71,1.96,0.005930161,0.78259146,1.72,0
2.2046561,0.9,2.37,0.62021697,1.0,2.07,0
1.0217602,0.49,1.89,-0.26944694,0.9999997,1.66,0
2.1021683,1.05,2.78,0.5306551,1.0,2.14,0
2.4789429,0.94,2.52,0.34224525,0.9999965,2.01,0
2.1449182,0.8,2.32,0.37609425,0.9997282,2.25,0
2.7071185,0.83,2.36,0.75363404,1.0,2.31,0
1.8445525,1.04,2.76,0.6075378,0.88632137,2.14,0
1.6024263,1.09,2.63,0.64461184,1.0,2.18,0
2.0292685,0.53,2.15,0.090091705,1.0,1.92,0
2.0858748,0.71,1.86,0.14351326,0.9999994,1.88,0
2.1292083,0.81,2.31,0.33257455,1.0,1.95,0
1.6344122,0.84,2.38,0.6371139,0.9999998,2.11,0
1.7532507,0.75,2.04,0.16182575,1.0,1.78,0
2.2479355,0.97,2.72,0.41953298,1.0,2.04,0
2.5790315,1.07,2.96,0.7216893,0.9999953,2.11,0
3.0039942,1.03,2.44,0.8042694,0.9998856,2.25,1
3.7599833,1.16,3.23,0.9095345,0.66683024,2.39,1
2.8912013,1.05,2.67,0.85215354,0.9967052,2.27,1
3.8784094,1.11,3.18,0.6971026,1.0,2.19,1
2.1862392,1.13,2.7,0.65855825,1.0,2.28,1
2.7684402,1.16,2.79,0.9261603,-0.9540385,2.35,1
1.7551649,0.56,2.18,0.23092282,1.0,1.98,1
2.804592,1.13,2.98,0.84827685,1.0,2.3,1
1.9874831,1.0,2.98,0.87599415,1.0,2.21,1
2.5059428,1.16,2.79,0.97649753,0.9997586,2.42,1
2.812127,1.12,3.11,0.87392867,1.0,2.21,1
2.9445121,1.06,3.17,0.8849491,1.0,2.41,1
2.7388847,1.11,2.78,0.84986275,0.96669436,2.32,1
2.1416433,1.1,3.61,0.7671358,0.9999998,2.29,1
2.3661094,1.05,3.16,0.73194104,0.99990827,2.14,1
2.761189,1.09,2.81,0.7681978,-0.99955946,2.23,1
2.6658804,1.02,3.36,0.8036201,0.98403203,2.28,1
2.720667,0.99,2.78,0.97055733,0.9781505,2.48,1
2.6812658,0.98,3.05,0.73290765,1.0,2.09,1
1.4784714,0.62,1.97,0.418,1.0,2.02,0
1.7488811,0.7,2.05,0.418,0.99999624,2.02,0

Test data:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma
1.6724254,0.95,2.58,0.92031854,1.0,2.15
2.552926,0.93,2.74,0.63588345,-0.30092865,2.18
2.5737462,0.86,2.22,0.43023747,1.0,2.08
2.1701677,0.62,2.19,0.6892167,1.0,2.15
3.6152358,0.96,2.58,0.67760235,0.99704355,2.06
3.6193092,0.82,2.34,0.4083981,0.9973078,2.04
2.0209844,1.02,2.86,0.8595182,-0.9979041,2.36
2.166221,1.07,3.0,0.7177616,-0.99961376,2.3
2.7933478,0.94,2.4,0.678935,1.0,2.12
2.2969048,0.86,2.29,0.18689133,1.0,1.96
3.1255674,1.15,2.77,0.9290483,0.6387009,2.28
2.3548958,1.01,2.46,0.75331503,-1.0,2.21
3.9791226,1.15,3.04,0.87006325,-0.99919724,2.43
2.3430493,0.85,2.42,0.81132597,-0.9999996,2.04
3.7431624,0.79,2.57,0.704,0.99952716,2.20784
3.1846259,1.14,2.85,0.9104803,0.99891067,2.3
3.1416001,0.73,2.26,0.5679769,1.0,1.98
2.670179,0.85,2.66,0.7376513,0.97939825,2.1
3.010911,0.79,2.38,0.21750104,0.21187924,1.82
1.4430648,0.9,2.38,0.7361963,0.999758,2.11
2.8149416,1.07,2.62,0.94750744,0.9967568,2.4
3.8395922,1.09,2.91,0.27485812,0.99887043,2.05
3.1686394,0.66,2.11,0.529385,1.0,1.9
3.190167,1.09,3.1,0.8501991,0.9507157,2.23
3.8597586,1.13,3.64,0.89043206,0.17880388,2.42
2.1516426,0.85,2.24,0.6673518,0.9985168,2.2
2.1318088,0.98,2.64,0.85542095,1.0,2.22
1.6740437,0.97,2.99,0.86632746,0.9983954,2.41
4.273427,1.01,2.71,0.8941501,0.64256436,2.47
2.284782,0.92,2.7,0.5820462,0.6981752,2.1
3.343603,1.06,2.84,0.6901738,0.83269715,2.13
5.766362,1.2,3.74,0.99009913,0.99998844,2.49
2.1547525,0.95,3.02,0.75229234,0.99604213,2.57
2.9853358,0.91,2.37,0.62881154,-0.98792726,2.06
2.8614197,0.82,2.15,0.75643075,1.0,2.19
3.6815813,1.14,3.24,0.8886577,-0.030438267,2.39
4.539201,1.17,2.83,0.93989134,0.23378997,2.55
3.35261,1.1,2.73,0.9184936,0.9998006,2.41
3.6697345,1.16,3.57,0.9515105,0.9999988,2.43
1.9781204,0.91,2.85,-0.06649571,0.9999991,1.7
2.6618617,1.1,3.24,0.8348949,-0.9834342,2.29
3.8140056,1.18,3.25,0.8766021,1.0,2.39
2.1926181,1.05,2.3,0.6880097,1.0,2.3
2.0248337,0.83,2.29,0.3604591,0.46159065,2.05
3.904931,1.13,2.46,0.9100119,1.0,2.32
1.9945884,0.94,2.5,0.4632657,0.9869119,2.05
3.3342967,1.1,3.04,0.51323855,-0.5262294,2.23
2.3138714,0.91,2.36,0.90414697,0.9999977,2.29
2.3118904,1.04,3.01,0.87289846,0.998577,2.29
2.246307,1.07,2.72,0.6147379,0.9999993,2.11
1.6369493,0.89,2.34,0.61421084,0.9997295,2.22
3.6198807,0.93,2.62,0.7463702,0.9994778,2.07

  • It depends on the data. Maybe your data is such that the decision tree is able to successfully separate the classes (maybe the tree is overfitting), and hence the tree knows for sure that the particular instance belongs to class 1 with 100% probability. Without seeing the data samples, we can't do much. – Vivek Kumar Jan 12 '18 at 05:50
  • I am using 100 points for training and 6 parameters. Do you think this is causing the problem? – akaur Jan 12 '18 at 15:46
  • Have a look at https://martin-thoma.com/comparing-classifiers/ (code is there). Try other classifiers, especially logistic regression and k-nearest neighbors. What is the result for them? – Martin Thoma Jan 15 '18 at 07:04
  • On the data you have given, the tree is achieving 100% accuracy on the training set, so it's completely fitting the data. – Vivek Kumar Jan 15 '18 at 07:09

1 Answer


There is no problem - the tree behaves exactly as expected.

A decision tree computes the class probability from the number of samples of each class that fall into a given leaf.
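For illustration, a minimal sketch on synthetic toy data (not the data from the question):

import numpy as np
from sklearn import tree

X = np.array([[0.], [1.], [2.], [3.], [4.], [5.]])
y = np.array([0, 0, 1, 0, 1, 1])

# a depth-1 tree, so at least one leaf holds a mix of classes
clf = tree.DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y)

leaves = clf.apply(X)  # leaf id of every training sample
for leaf in np.unique(leaves):
    print(leaf, y[leaves == leaf].mean())  # fraction of class 1 in this leaf

# predict_proba[:, 1] for a sample equals the class-1 fraction of its leaf
print(clf.predict_proba(X)[:, 1])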

The documentation says:

The default values for the parameters controlling the size of the trees (e.g. max_depth, min_samples_leaf, etc.) lead to fully grown and unpruned trees

I.e., the tree is grown until it perfectly (over)fits the training data. This means that all training samples in each leaf are of the same class, so a test sample either matches that class (p=1) or does not (p=0).

To get finer probability estimates, restrict min_samples_leaf so that there is a minimum number of samples in each leaf; those samples are then used to compute the probabilities (with one sample per leaf you can only get [0, 1]; with e.g. 10 samples you can get [0, 0.1, 0.2, ..., 0.9, 1]). You will have to experiment with the settings to find which numbers work best for you and your data.
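A minimal sketch, assuming `data`, `identifier`, and `test_data` are loaded as in the question (`min_samples_leaf=10` is just a starting point to tune):

from sklearn import tree

clf = tree.DecisionTreeClassifier(min_samples_leaf=10, random_state=0)
trained = clf.fit(data, identifier)

# each leaf now holds at least 10 training samples, so the predicted
# probabilities come in steps of about 1/10 instead of only 0 and 1
proba = trained.predict_proba(test_data)[:, 1]
print(proba)

With more distinct probability values, `roc_curve` will also yield more than 3 thresholds, so the ROC curve becomes meaningful.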

  • Thank you very much for clarifying it. Then the problem is the ROC curve, which will only show 3 points in this case, because of the two probability values plus one more at infinity. How do I go about presenting my results from this classification? – akaur Jan 16 '18 at 16:07
  • I would restrict `min_samples_leaf` so that there is a minimum number of samples in each leaf, which will be used to compute probabilities (with one sample you get [0, 1] - with e.g. 10 samples you can get [0, 0.1, 0.2, ..., 0.9, 1]). You are going to have to experiment with the settings to find which numbers work best for you and your data. – MB-F Jan 22 '18 at 07:22
  • I am using SGD and there is no such parameter in the sklearn implementation. For me it seems that changing the alpha parameter more or less normalizes the predicted probabilities. – Federico Taranto Aug 24 '22 at 12:53
  • @FedericoTaranto The problem here is specific to decision trees. But you are right that overfitting in stochastic gradient descent is controlled with the regularization parameter. – MB-F Aug 24 '22 at 15:02