I need to fix the value of a parameter of a scikit-learn estimator. I still need to be able to change all the other parameters of the estimator, and to use the estimator within scikit-learn tools such as Pipelines and GridSearchCV.

I tried to define a new class inheriting from a scikit-learn estimator. For instance, here I am trying to create a new class that fixes n_estimators=5 of a RandomForestClassifier.

from sklearn.ensemble import RandomForestClassifier


class FiveTreesClassifier(RandomForestClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_estimators = 5


fivetrees = FiveTreesClassifier()
randomforest = RandomForestClassifier(n_estimators=5)

# This passes.
assert fivetrees.n_estimators == randomforest.n_estimators
# This fails: fivetrees.get_params() returns an empty dict.
assert fivetrees.get_params() == randomforest.get_params()

The fact that get_params() is not reliable means that I cannot use the new estimator within Pipelines and GridSearchCV (as explained here).
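
The empty dict has a mechanical cause: BaseEstimator.get_params() reads the parameter names from the signature of __init__ and skips a **kwargs catch-all, so a constructor that accepts only **kwargs exposes no parameters at all. Below is a minimal sketch of that lookup using only the standard library. The helper init_param_names and both toy classes are hypothetical stand-ins for illustration, not scikit-learn code.

```python
import inspect


def init_param_names(cls):
    """Return the named __init__ parameters, skipping self and **kwargs.

    Simplified imitation of a signature-based parameter lookup;
    this is an illustrative assumption, not the library's implementation.
    """
    signature = inspect.signature(cls.__init__)
    return sorted(
        name
        for name, param in signature.parameters.items()
        if name != "self" and param.kind is not inspect.Parameter.VAR_KEYWORD
    )


class ExplicitInit:
    # Every hyperparameter is a named keyword argument.
    def __init__(self, n_estimators=100, max_depth=None):
        self.n_estimators = n_estimators
        self.max_depth = max_depth


class KwargsOnlyInit(ExplicitInit):
    # Same pattern as FiveTreesClassifier: only **kwargs in the signature.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_estimators = 5


explicit_names = init_param_names(ExplicitInit)
kwargs_only_names = init_param_names(KwargsOnlyInit)
print(explicit_names)     # ['max_depth', 'n_estimators']
print(kwargs_only_names)  # []
```

The attribute n_estimators is set correctly on the instance in both cases; what goes missing is the list of names that introspection can see.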

I am using scikit-learn 0.24.2, but I think it would actually be the same with newer versions.

I would prefer answers that let me define a new class while fixing the value of a hyperparameter, but I would also accept answers that use other techniques. I would also be thankful for thorough explanations of why I should or should not do this!

Enrico Gandini

1 Answer

You can use functools.partial:

from functools import partial
from sklearn.ensemble import RandomForestClassifier

NewEstimator = partial(RandomForestClassifier, n_estimators=5)
new_estimator = NewEstimator()
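
One caveat worth checking before relying on this: functools.partial presets a keyword argument but does not lock it, so a caller can still pass a different value, and the result is a callable rather than a new class. A minimal sketch of that behaviour follows. ToyEstimator is a hypothetical stand-in for an estimator such as RandomForestClassifier, used here so the snippet runs with the standard library alone.

```python
from functools import partial


class ToyEstimator:
    """Hypothetical stand-in for an estimator such as RandomForestClassifier."""

    def __init__(self, n_estimators=100, max_depth=None):
        self.n_estimators = n_estimators
        self.max_depth = max_depth


NewEstimator = partial(ToyEstimator, n_estimators=5)

# The preset applies when the caller does not mention n_estimators.
preset = NewEstimator(max_depth=3)
print(preset.n_estimators, preset.max_depth)  # 5 3

# A keyword passed at call time overrides the preset.
overridden = NewEstimator(n_estimators=10)
print(overridden.n_estimators)  # 10

# The instances are plain ToyEstimator objects; NewEstimator is not a class.
print(type(preset) is ToyEstimator)    # True
print(isinstance(NewEstimator, type))  # False
```

So partial is a convenient default, but by itself it does not guarantee that n_estimators stays at 5.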
Franco Piccolo
  • Can you give me some more details? I need to make sure that the class derived from RandomForestClassifier always has 'n_estimators=5'. The other parameters will be set freely for any new instance of the class, but n_estimators must always be 5. – Enrico Gandini Nov 25 '21 at 08:09
  • I wouldn't create a new class. Otherwise you need to pass all the parameters on initialization; the reason you are getting an empty dict is that you are passing no parameters. I would just set n_estimators=5 every time. – Franco Piccolo Nov 25 '21 at 20:48
  • Yes, you are probably right, but I still need to be sure that it is not possible to instantiate a RandomForestClassifier with 'n_estimators!=5'. I am not the only developer / user of the code I am writing. How can I make sure that nobody ever instantiates a RandomForestClassifier with 'n_estimators!=5'? Do I have to place assertions and checks in many places in my code? Or is there a solution to make 'n_estimators=5' mandatory forever? Not necessarily a new class, maybe something else, maybe some kind of wrapper, I do not know, that's why I am asking! – Enrico Gandini Nov 25 '21 at 22:46
  • By the way: the RandomForestClassifier is just an example, I want to know in general how to fix a parameter for scikit-learn estimators. – Enrico Gandini Nov 25 '21 at 22:52
  • check my new answer – Franco Piccolo Nov 25 '21 at 23:19
  • Thank you, I have tested it and it works! I actually cannot use the answer, since I also needed to redefine the 'predict' method. But this other requirement was not specified in the question, so your answer is accepted. – Enrico Gandini Nov 26 '21 at 13:07