
I am trying to generate a heatmap of GridSearchCV results from sklearn. What I like about sklearn-evaluation is how easy it makes generating the heatmap. However, I have hit one issue when one of the parameter values is None, for example:

max_depth = [3, 4, 5, 6, None]

generating the heatmap fails with the error:

TypeError: '<' not supported between instances of 'NoneType' and 'int'
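The error is not specific to sklearn-evaluation: in Python 3, None cannot be ordered relative to an int, so any sorted() call over a mix of the two fails. A minimal reproduction using only the standard library (the helper name try_sort is just for illustration):

```python
max_depth = [3, 4, 5, 6, None]

def try_sort(values):
    """Return the sorted list, or the TypeError message if sorting fails."""
    try:
        return sorted(values)
    except TypeError as exc:
        # Python 3 refuses to compare None with int, so a mixed list cannot be sorted
        return str(exc)

print(try_sort(max_depth))  # prints the TypeError message instead of a sorted list
```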

Is there any workaround for this? I have found other ways to generate heatmaps, such as matplotlib and seaborn, but none of them produce heatmaps as nice as sklearn-evaluation's.


spockshr

1 Answer


I modified the file /lib/python3.8/site-packages/sklearn_evaluation/plot/grid_search.py. At lines 192-193, change these lines

From:

row_names = sorted(set([t[0] for t in matrix_elements.keys()]),
                   key=itemgetter(1))
col_names = sorted(set([t[1] for t in matrix_elements.keys()]),
                   key=itemgetter(1))

To:

row_names = sorted(set([t[0] for t in matrix_elements.keys()]),
                   key=lambda x: (x[1] is None, x[1]))
col_names = sorted(set([t[1] for t in matrix_elements.keys()]),
                   key=lambda x: (x[1] is None, x[1]))

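The trick of moving all None values to the end of the list while sorting is based on a previous answer by Andrew Clarke. The new key returns a tuple whose first element is False for ordinary values and True for None, so the ordering is decided before None is ever compared with an int.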
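The effect of the two keys can be checked in isolation. This sketch assumes, as the patched lines imply, that each row/column entry is a (parameter_name, value) pair; the sample entries are made up for illustration:

```python
from operator import itemgetter

# Hypothetical row entries shaped like (parameter_name, value) pairs
row_entries = {("max_depth", 3), ("max_depth", None), ("max_depth", 1), ("max_depth", 2)}

def none_last(entry):
    # (False, value) for real values, (True, None) for None: tuples compare
    # element by element, so the first element decides before None meets an int
    return (entry[1] is None, entry[1])

print(sorted(row_entries, key=none_last))
# [('max_depth', 1), ('max_depth', 2), ('max_depth', 3), ('max_depth', None)]

try:
    sorted(row_entries, key=itemgetter(1))  # the original key
except TypeError as exc:
    print("original key fails:", exc)
```

Note that this key only handles None; a single parameter whose values mix, say, strings and ints would still raise the same kind of TypeError.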

With this tweak in place, the following demo script works:

import matplotlib.pyplot as plt
import sklearn.datasets as datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn_evaluation import plot

# Small synthetic classification problem
X, y = datasets.make_classification(n_samples=200, n_features=10, n_informative=4, class_sep=0.5)

# max_depth deliberately includes None, which triggers the sorting error without the patch
hyperparameters = {
    "max_depth": [1, 2, 3, None],
    "criterion": ["gini", "entropy"],
    "max_features": ["sqrt", "log2"],
}

est = RandomForestClassifier(n_estimators=5)
clf = GridSearchCV(est, hyperparameters, cv=3)
clf.fit(X, y)

# Heatmap of max_depth vs criterion, holding max_features fixed at "sqrt"
plot.grid_search(clf.cv_results_, change=("max_depth", "criterion"), subset={"max_features": "sqrt"})
plt.show()

The output is shown below:

[Image: resulting grid search heatmap]

lifezbeautiful
This is a nice solution. For now I am using a seaborn heatmap, which requires some manual data manipulation. It would be nice to see this fix merged into sklearn-evaluation. – spockshr Jun 26 '21 at 12:19
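For the seaborn route mentioned in the comment, the manual data manipulation amounts to pivoting cv_results_ into a 2-D grid. Below is a minimal standard-library sketch, assuming the usual cv_results_ keys "params" (a list of dicts) and "mean_test_score"; the helper name results_grid and the toy scores are hypothetical. The returned labels and matrix could then be passed to seaborn.heatmap or matplotlib's imshow (not shown here).

```python
def none_last(value):
    # Real values first, None last, without ever comparing None to an int
    return (value is None, value)

def results_grid(cv_results, row_param, col_param, subset=None):
    """Pivot GridSearchCV-style results into (row_labels, col_labels, matrix).

    Hypothetical helper: cells with no matching result are left as None.
    """
    subset = subset or {}
    scores = {}
    for params, score in zip(cv_results["params"], cv_results["mean_test_score"]):
        # Keep only the runs that match every fixed parameter in `subset`
        if all(params.get(k) == v for k, v in subset.items()):
            scores[(params[row_param], params[col_param])] = score

    rows = sorted({r for r, _ in scores}, key=none_last)
    cols = sorted({c for _, c in scores}, key=none_last)
    matrix = [[scores.get((r, c)) for c in cols] for r in rows]
    # Convert labels to strings so None shows up as "None" on the axis
    return [str(r) for r in rows], [str(c) for c in cols], matrix

# Toy results in the shape GridSearchCV produces (scores are made up)
toy = {
    "params": [
        {"max_depth": d, "criterion": c, "max_features": "sqrt"}
        for d in (1, None, 2) for c in ("gini", "entropy")
    ],
    "mean_test_score": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
}
rows, cols, matrix = results_grid(toy, "max_depth", "criterion", subset={"max_features": "sqrt"})
print(rows, cols, matrix)
```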