Here is what I eventually did. I wrote a wrapper function that returns a child class of an estimator (e.g. `LogisticRegression`) with an augmented `predict_proba` and an augmented `fit`. The augmented `fit` records which labels it has seen in `y_train`. The augmented `predict_proba` fills with zeros the columns for labels that are present in `labels` but were not present in `y_train`.
def predict_proba_wrapper(method: "Callable", labels: list):
    """Wrap ``predict_proba`` so its output has one column per entry of ``labels``.

    The wrapped estimator only produces probability columns for the labels it
    saw during ``fit`` (read from ``self.labels_seen``). The wrapper places each
    of those columns at the position its label occupies in ``labels`` and fills
    the columns of labels that were never seen with zeros.

    Assumes, as with scikit-learn (whose ``classes_`` is ``np.unique(y)``), that
    the wrapped method orders its columns like the sorted distinct seen labels.
    Unlike a positional insert, this keeps the mapping right even when
    ``labels`` itself is not sorted.

    Args:
        method: The unbound ``predict_proba`` function to wrap.
        labels: The full, ordered list of labels the output columns represent.

    Returns:
        A function with the same signature as ``method``. It returns the same
        container type that ``method`` returned: ``np.ndarray``, a
        ``pd.DataFrame`` keeping the original row index, or a nested ``list``.
        It raises ``ValueError`` if the output type is unsupported, if the output
        is not 2-D, if its column count does not match the number of seen
        labels, or if a seen label is absent from ``labels``.
    """
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        # Call the estimator first so an unfitted model fails with its own error
        # rather than an AttributeError on ``labels_seen``.
        y_pred = method(self, *args, **kwargs)
        if not isinstance(y_pred, (np.ndarray, pd.DataFrame, list)):
            raise ValueError(f"y_pred type {type(y_pred)} not supported")

        proba = y_pred.to_numpy() if isinstance(y_pred, pd.DataFrame) else np.asarray(y_pred)
        if proba.ndim != 2:
            raise ValueError(f"predict_proba output must be 2-D, got shape {proba.shape}")

        # Column order of the estimator's output (sorted distinct seen labels).
        seen = np.unique(self.labels_seen)
        if proba.shape[1] != len(seen):
            raise ValueError(
                f"predict_proba returned {proba.shape[1]} columns but "
                f"{len(seen)} labels were seen during fit"
            )

        position = {label: i for i, label in enumerate(labels)}
        unknown = [label for label in seen if label not in position]
        if unknown:
            raise ValueError(f"labels seen during fit {unknown} are not in labels {labels}")

        # Scatter every seen column into its slot; unseen labels stay at zero.
        filled = np.zeros((proba.shape[0], len(labels)), dtype=proba.dtype)
        filled[:, [position[label] for label in seen]] = proba

        if isinstance(y_pred, pd.DataFrame):
            return pd.DataFrame(filled, index=y_pred.index)
        if isinstance(y_pred, list):
            return filled.tolist()
        return filled
    return wrapper
def fit_wrapper(method: "Callable", labels: list):
    """Wrap ``fit`` so the estimator records which labels it was trained on.

    After the original ``fit`` runs, the wrapper stores the distinct labels
    found in ``y`` as ``self.labels_seen``. The order of that list depends on
    ``y``'s type: sorted for ``np.ndarray``, first-appearance for pandas
    objects, arbitrary for ``list``. If the estimator has a ``classes_``
    attribute, it is replaced by ``labels``, converted to the container type
    ``classes_`` already had.

    Args:
        method: The unbound ``fit`` function to wrap. ``y`` is taken from the
            second positional argument, or else from the ``y`` keyword.
        labels: The full list of labels every fitted model should report.

    Returns:
        A function with the same signature as ``method``. It returns whatever
        ``method`` returned. It raises ``ValueError`` if ``y`` or ``classes_``
        has an unsupported type, or if ``y`` contains a label that is not in
        ``labels``. Such a label would otherwise produce probability matrices
        whose columns no longer line up with ``labels``.

    NOTE(review): estimators whose ``predict`` indexes ``classes_`` directly
    with an argmax over their own (seen-class) scores will map predictions
    onto the wrong entries once ``classes_`` is replaced here. Verify
    ``predict`` output for each wrapped estimator type.
    """
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        res = method(self, *args, **kwargs)
        # fit(X, y, ...): y is the second positional argument, otherwise y=...
        y = args[1] if len(args) >= 2 else kwargs["y"]

        if isinstance(y, np.ndarray):
            labels_seen = list(np.unique(y))
        elif isinstance(y, pd.DataFrame):
            # Only the first column is treated as the target.
            labels_seen = list(y.iloc[:, 0].unique())
        elif isinstance(y, pd.Series):
            labels_seen = list(y.unique())
        elif isinstance(y, list):
            labels_seen = list(set(y))
        else:
            raise ValueError(f"y type {type(y)} not supported")

        unknown = [label for label in labels_seen if label not in labels]
        if unknown:
            raise ValueError(f"y contains labels {unknown} that are not in labels {labels}")
        self.labels_seen = labels_seen

        if hasattr(self, "classes_"):
            classes = self.classes_
            if isinstance(classes, np.ndarray):
                self.classes_ = np.array(labels)
            elif isinstance(classes, pd.DataFrame):
                self.classes_ = pd.DataFrame(labels)
            elif isinstance(classes, pd.Series):
                self.classes_ = pd.Series(labels)
            elif isinstance(classes, list):
                self.classes_ = labels
            else:
                raise ValueError(f"classes_ type {type(classes)} not supported")
        return res
    return wrapper
def class_child_with_wrapped_methods(class_: Type, method_names: List[str], wrappers: List[callable]):
    """Return a subclass of ``class_`` whose listed methods are wrapped.

    The subclass is named ``<class_.__name__>Wrapped``. For each pair
    ``(method_names[i], wrappers[i])``, the inherited method is looked up on
    the new class and replaced by ``wrappers[i](method)``. ``class_`` itself
    is not modified.

    Args:
        class_: The class to derive from.
        method_names: Names of the methods to wrap.
        wrappers: One callable per name. Each takes the original function and
            returns its replacement.

    Returns:
        The newly created subclass.

    Raises:
        ValueError: If ``method_names`` and ``wrappers`` differ in length.

    NOTE(review): the subclass is created with ``type()`` and is not bound to
    a module-level name. Pickle, which locates classes by qualified name,
    presumably cannot serialize its instances. Verify before persisting models
    or using process-based parallelism.
    """
    if len(method_names) != len(wrappers):
        raise ValueError(
            f"got {len(method_names)} method names but {len(wrappers)} wrappers"
        )
    new_class = type(f"{class_.__name__}Wrapped", (class_,), {})
    for method_name, wrap in zip(method_names, wrappers):
        setattr(new_class, method_name, wrap(getattr(new_class, method_name)))
    return new_class
def wrap_fit_predict_proba(class_: Type, labels: list):
    """Return a subclass of ``class_`` with both ``predict_proba`` and ``fit`` wrapped.

    ``predict_proba`` is wrapped by ``predict_proba_wrapper`` and ``fit`` by
    ``fit_wrapper``. Both receive the same ``labels`` list, so every fitted
    instance reports probabilities over the full label set.
    """
    def wrap_predict_proba(method):
        return predict_proba_wrapper(method, labels)

    def wrap_fit(method):
        return fit_wrapper(method, labels)

    # Insertion order is preserved: predict_proba is wrapped first, then fit.
    wrapper_for = {
        "predict_proba": wrap_predict_proba,
        "fit": wrap_fit,
    }
    return class_child_with_wrapped_methods(
        class_,
        list(wrapper_for.keys()),
        list(wrapper_for.values()),
    )
# Wrapped estimator classes (not instances); each reports probability
# columns for the same ordered label set.
CLASSIFIERS = [
    wrap_fit_predict_proba(LogisticRegression, labels=[-1, 0, 1]),
    wrap_fit_predict_proba(ExtraTreesClassifier, labels=[-1, 0, 1]),
]