I was reading the scikit-learn documentation for GridSearchCV. One of the attributes of a GridSearchCV object is best_estimator_, which leads to my question: how can I pass more than one estimator to a GridSearchCV object?

For example, by using a dictionary like {'SVC()': {'C': 10, 'gamma': 0.01}, 'DecTreeClass()': {....}}?

Vivek Kumar
mikinoqwert

1 Answer

GridSearchCV works on parameters. It trains multiple estimators of the same class (SVC, DecisionTreeClassifier, or any other single classifier), one for each parameter combination specified in param_grid. best_estimator_ is the estimator that performs best on the data, that is, the one with the best mean cross-validated score.

So essentially best_estimator_ is the same class object initialized with best found params.
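As a minimal sketch of that point, the snippet below runs a plain grid search over a single SVC and checks that best_estimator_ is an SVC carrying the winning parameters. The small grid and cv=3 are assumptions chosen only for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Hypothetical small grid, chosen only to keep the example fast.
param_grid = {"C": [1, 10], "gamma": [0.01, 0.1]}

grid = GridSearchCV(SVC(), param_grid, cv=3)
grid.fit(X, y)

best = grid.best_estimator_

# best_estimator_ has the same class as the estimator passed in ...
is_svc = isinstance(best, SVC)
# ... and its parameters agree with best_params_.
params_match = all(
    best.get_params()[name] == value
    for name, value in grid.best_params_.items()
)

print(type(best).__name__, grid.best_params_)
```

Here best_params_ only ever contains keys from param_grid, so there is no place in this basic setup to name a second estimator class.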

So in the basic setup you cannot use multiple estimators in the grid-search.

But as a workaround, you can have multiple estimators when using a pipeline in which the estimator is a "parameter" which the GridSearchCV can set.

Something like this:

from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_iris
iris_data = load_iris()
X, y = iris_data.data, iris_data.target


# Just initialize the pipeline with any estimator you like    
pipe = Pipeline(steps=[('estimator', SVC())])

# Add a dict of estimator and estimator related parameters in this list
params_grid = [
    {
        'estimator': [SVC()],
        'estimator__C': [1, 10, 100, 1000],
        'estimator__gamma': [0.001, 0.0001],
    },
    {
        'estimator': [DecisionTreeClassifier()],
        'estimator__max_depth': [1, 2, 3, 4, 5],
        # "auto" is omitted here: newer scikit-learn releases reject it
        'estimator__max_features': [None, "sqrt", "log2"],
    },
    # {'estimator': [Any_other_estimator_you_want],
    #  'estimator__valid_param_of_your_estimator': [valid_values]},
]

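grid = GridSearchCV(pipe, params_grid)
grid.fit(X, y)

# best_params_ includes the winning 'estimator' object and its parameters
print(grid.best_params_)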

You can add as many dicts to the params_grid list as you like, but make sure that each dict contains only parameters that are valid for the 'estimator' it names.
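The SVC() passed when constructing the pipeline is only a placeholder. A rough sketch of why, assuming the grid search applies each candidate through set_params on a copy of the pipeline: setting the step name replaces the whole step, and the nested estimator__ parameters are then applied to the replacement.

```python
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Start with SVC, exactly as in the pipeline above.
pipe = Pipeline(steps=[("estimator", SVC())])

# One candidate from the second dict of the grid, applied by hand.
# 'estimator' swaps the step object; 'estimator__max_depth' is then
# set on the new DecisionTreeClassifier, not on the original SVC.
pipe.set_params(
    estimator=DecisionTreeClassifier(),
    estimator__max_depth=3,
)

swapped = pipe.named_steps["estimator"]
print(type(swapped).__name__, swapped.max_depth)
```

This is also why each dict must keep its parameters consistent with its own 'estimator' entry: a name such as estimator__gamma would fail once the step has been replaced by a DecisionTreeClassifier.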

Vivek Kumar
I don't get this. Since you're starting the Pipeline with only SVC, how is it possible to pass params for other classifiers? The API is a little confusing there :( – George Katsanos Dec 07 '22 at 14:03