How can I use Yellowbrick on the output of non-scikit-learn models?
I have a PyTorch multi-class classifier network and would like to use Yellowbrick's ClassificationReport on the results of applying this model to data. How can I do this?
How can I use Yellowbrick on the output of non-scikit-learn models?
I have a PyTorch multi-class classifier network and would like to use Yellowbrick's ClassificationReport on the results of applying this model to data. How can I do this?
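If you use the skorch library, which makes PyTorch models scikit-learn compatible, you can pass the resulting estimator through Yellowbrick's third-party wrapper (`yellowbrick.contrib.wrapper.wrap`) and then use visualizers such as ClassificationReport on it.
Note that the network's final `nn.Linear` layer must output one value per class: the example below uses a binary dataset and therefore 2 outputs, so for a multi-class problem set that output size to your number of classes. Here is some example code: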
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
from sklearn.model_selection import train_test_split
from skorch import NeuralNetClassifier
# Synthetic classification data: 1000 samples with 20 features, 10 of which
# are informative. The arrays are then cast to the dtypes the PyTorch model
# is trained on here (float32 features, int64 class labels).
X, y = make_classification(
    n_samples=1000, n_features=20, n_informative=10, random_state=0
)
X, y = X.astype(np.float32), y.astype(np.int64)

# Hold out a test split (train_test_split's default proportion) for the report.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42
)
class MyModule(nn.Module):
    """Small feed-forward classifier: two hidden layers, dropout, softmax output.

    Args:
        num_units: Width of both hidden layers.
        nonlin: Activation module applied after each hidden layer. ``None``
            (the default) builds a fresh ``nn.ReLU()`` for this instance,
            rather than sharing one module object created at definition time
            across every network that relies on the default.
        input_dim: Number of input features per sample. The default of 20
            matches the ``make_classification`` call in this example.
        num_classes: Number of output classes (width of the final layer).
            For a multi-class problem set it to the number of classes, e.g.
            through skorch with ``module__num_classes=...``.
    """

    def __init__(self, num_units=10, nonlin=None, input_dim=20, num_classes=2):
        super().__init__()
        self.dense0 = nn.Linear(input_dim, num_units)
        # Create the default activation per instance (no shared default object).
        self.nonlin = nn.ReLU() if nonlin is None else nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, num_units)
        self.output = nn.Linear(num_units, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        """Map a batch of feature rows to class probabilities.

        For 2-D input of shape (n_samples, input_dim) the result has shape
        (n_samples, num_classes) and each row sums to 1. Extra keyword
        arguments are accepted and ignored.
        """
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = self.nonlin(self.dense1(X))
        # NOTE: returns probabilities, not log-probabilities. skorch's
        # NeuralNetClassifier (default criterion NLLLoss) expects this;
        # revisit if you switch to a different loss.
        return self.softmax(self.output(X))
# skorch estimator around the PyTorch module, giving it the fit/predict
# interface that scikit-learn-style tooling such as Yellowbrick works with.
# MyModule is passed as a class; skorch builds the module instance itself.
net = NeuralNetClassifier(
    module=MyModule,
    max_epochs=10,
    lr=0.1,
    iterator_train__shuffle=True,  # reshuffle the training data every epoch
)
# Yellowbrick's third-party wrapper plus the ClassificationReport quick method.
from yellowbrick.contrib.wrapper import wrap
from yellowbrick.classifier import classification_report

# Wrap the skorch net so Yellowbrick treats it as an estimator, then train it
# once on the training split.
model = wrap(net)
model.fit(X_train, y_train)

# Draw the classification report on the held-out split. The model was fit
# just above, so is_fitted=True tells Yellowbrick to use it as-is instead of
# fitting it again; support=True also shows each class's support.
oz = classification_report(
    model,
    X_train,
    y_train,
    X_test=X_test,
    y_test=y_test,
    support=True,
    is_fitted=True,
)