
I am trying to incorporate PyTorch functionality into a scikit-learn setup (in particular Pipelines and GridSearchCV), and have therefore been looking into skorch. The standard documentation example for neural networks looks like this:

import torch.nn.functional as F
from torch import nn
from skorch import NeuralNetClassifier

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        ...
        ...
        self.output = nn.Linear(10, 2)
    ...
    ...

where you explicitly pass the input and output dimensions by hardcoding them into the constructor. However, this is not really how scikit-learn interfaces work: there, the input and output dimensions are inferred by the fit method rather than passed explicitly to the constructors. As a practical example, consider:

# copied from the documentation
net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

# any general Pipeline interface
pipeline = Pipeline([
    ('transformation', AnyTransformer()),
    ('net', net)
])

# params: a hyperparameter grid, e.g. {'net__lr': [0.05, 0.1]}
gs = GridSearchCV(pipeline, params, refit=False, cv=3, scoring='accuracy')
gs.fit(X, y)

Besides the fact that nowhere in the transformers does one specify the input and output dimensions, the transformers applied before the model may change the dimensionality of the training set (think of dimensionality reduction and the like), so hardcoding the input and output sizes in the neural network constructor simply will not do.

Did I misunderstand how this is supposed to work, or otherwise what would be a suggested solution? (I was thinking of moving the layer construction into the forward method, where X is available at fit time, but I am not sure this is good practice.)
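For concreteness, a sketch of the deferred construction I have in mind, using PyTorch's `nn.LazyLinear` (available in newer PyTorch versions), which infers `in_features` on the first forward pass; the module itself is illustrative:

import torch
from torch import nn

class LazyModule(nn.Module):
    def __init__(self, num_units=10):
        super().__init__()
        # in_features is inferred from the first batch seen in forward
        self.dense0 = nn.LazyLinear(num_units)
        self.output = nn.Linear(num_units, 2)

    def forward(self, X):
        X = torch.relu(self.dense0(X))
        return torch.softmax(self.output(X), dim=-1)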

gented
  • Excellent point. Just to point out that the sklearn wrapper for tensorflow, `SciKeras`, makes setting input size dynamically a breeze, using the `meta["n_features_in_"]` parameter: see the [docs](https://adriangb.com/scikeras/stable/notebooks/MLPClassifier_MLPRegressor.html#2.1-Inputs). Output size too is easy using `meta["n_classes_"]`: see [docs](https://adriangb.com/scikeras/stable/notebooks/MLPClassifier_MLPRegressor.html#2.3-Output-layers). If `skorch` could introduce similar functionality that would make usage much easier! – gnoodle Aug 06 '23 at 14:32
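For reference, a minimal sketch of the SciKeras pattern described in the comment above, assuming the `meta` mechanism from the linked docs; the function name and layer sizes are illustrative:

from scikeras.wrappers import KerasClassifier
from tensorflow import keras

def build_model(meta):
    # SciKeras fills `meta` at fit() time with data-derived values
    n_features = meta["n_features_in_"]
    n_classes = meta["n_classes_"]
    model = keras.Sequential([
        keras.layers.Input(shape=(n_features,)),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(n_classes, activation="softmax"),
    ])
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
    return model

clf = KerasClassifier(model=build_model)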

1 Answer


This is a very good question and I'm afraid there is no best-practice answer to this, as PyTorch is normally written in a way where initialization and execution are separate steps, which is exactly what you don't want in this case.

There are several ways forward which all go in the same direction, namely introspecting the input data and re-initializing the network before fitting. The simplest way I can think of is writing a callback that sets the corresponding parameters when training begins:

class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        # X.shape[-1] is the feature count after all pipeline transformers
        net.set_params(module__input_dim=X.shape[-1])

This sets a module parameter when training begins, which will re-initialize the PyTorch module with said parameter. This specific callback expects that the parameter for the first layer is called `input_dim`, but you can change this if you want.

A full example:

import torch
import skorch
from sklearn.datasets import make_classification
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA

X, y = make_classification()
X = X.astype('float32')

class ClassifierModule(torch.nn.Module):
    def __init__(self, input_dim=80):
        super().__init__()
        # input_dim is the parameter the callback overrides at fit time
        self.l0 = torch.nn.Linear(input_dim, 10)
        self.l1 = torch.nn.Linear(10, 2)

    def forward(self, X):
        y = self.l0(X)
        y = self.l1(y)
        return torch.softmax(y, dim=-1)


class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        net.set_params(module__input_dim=X.shape[-1])


net = skorch.NeuralNetClassifier(
    ClassifierModule,
    callbacks=[InputShapeSetter()],
)

pipe = Pipeline([
    ('pca', PCA(n_components=10)),
    ('net', net),
])

pipe.fit(X, y)
print(pipe.predict(X))
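Since the original question was about GridSearchCV, here is a hedged sketch of how the pipeline above could be grid-searched; the parameter grid is illustrative, and the callback re-sets module__input_dim whenever the PCA output size changes:

from sklearn.model_selection import GridSearchCV

# Illustrative grid; names follow sklearn's <step>__<param> convention
params = {
    'pca__n_components': [5, 10],
    'net__lr': [0.05, 0.1],
}
gs = GridSearchCV(pipe, params, cv=3, scoring='accuracy')
gs.fit(X, y)
print(gs.best_params_)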
nemo
  • This works fine, it's a nice solution, thank you! However, a further limitation of `skorch` is apparently that the callbacks only work when the input is a `np.ndarray` (not for example a `pd.DataFrame/series` or a `list`). I guess one has to write additional callbacks to turn objects into the preferred format...but at this point one might as well just write a whole new estimator from scratch :) – gented Feb 12 '20 at 13:06
  • You can easily pass `pd.DataFrame` as `X` and the above example will work just fine. If you experience problems, feel free to ask a new question or open a new skorch issue. – nemo Feb 13 '20 at 09:17
  • I have passed a dataframe and I am getting exceptions (due to the presence of headers) in a pipeline; I am investigating this myself meanwhile :). – gented Feb 13 '20 at 10:27
  • OK! Note that when passing a data frame (or dictionary) the columns/entries are passed as named parameters to the `forward` method of your module, so you need to name these parameters accordingly (see the sketch after these comments). – nemo Feb 14 '20 at 09:30
  • Thanks @nemo for the workaround! Unfortunately it has a side effect in that every time `.fit()` is run, it prints ```Re-initializing module because the following parameters were re-set: ... Re-initializing criterion. Re-initializing optimizer.``` - doesn't appear to be a warning or info log. Is there any way this can be suppressed? – gnoodle Aug 06 '23 at 15:48
  • Just to add for others - you can disable the `Re-initializing` lines by using `verbose=False` in the NN. So in the example given, you can use `net = skorch.NeuralNetClassifier( ClassifierModule, callbacks=[InputShapeSetter()], verbose=False )`. However that will also mean the training progress for each epoch won't be displayed (but can still be accessed through the `history` attribute) – gnoodle Aug 06 '23 at 15:57
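A minimal sketch of the DataFrame behaviour nemo describes in the comments above, assuming skorch's column-to-keyword-argument handling; the module and column names here are illustrative:

import torch

# Hypothetical columns "age" and "income": skorch passes each DataFrame
# column to forward() as a keyword argument of the same name.
class DataFrameModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(2, 2)

    def forward(self, age, income):
        # Recombine the named columns into one feature tensor
        X = torch.stack([age.float(), income.float()], dim=-1)
        return torch.softmax(self.l0(X), dim=-1)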