
This question is pretty similar to this one and is based on this post on GitHub: I am trying to convert an SVM multiclass classification model (e.g., one built with sklearn) into a Keras model.

Specifically, I am looking for a way to retrieve probabilities (similar to SVC with probability=True) or a confidence value at the end, so that I can define some sort of threshold and distinguish between trained classes and untrained ones. That is, if I train my model on 3 or 4 classes and then feed it samples from a 5th class it was never trained on, it will still output some prediction, even if that prediction is totally wrong. I want to avoid that somehow.
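For illustration, the threshold idea could look like the following minimal NumPy sketch. It assumes the model already outputs one probability per class (one row per sample); the function name `predict_with_rejection`, the value `0.8` and the label `-1` for "unknown" are hypothetical choices, not part of any library:

```python
import numpy as np

def predict_with_rejection(probs, threshold=0.8):
    """Return the most likely class per row, or -1 ("unknown") when the
    top probability is below the threshold."""
    probs = np.asarray(probs, dtype=float)
    top = probs.max(axis=1)        # confidence of the best class per sample
    labels = probs.argmax(axis=1)  # index of the best class per sample
    return np.where(top >= threshold, labels, -1)

# first sample is confident, second one is rejected
print(predict_with_rejection([[0.9, 0.05, 0.05], [0.4, 0.35, 0.25]]))
```

Max-probability thresholding is a common baseline, but a softmax model can still be confidently wrong on inputs from a class it never saw, so the threshold would need to be tuned on held-out data that includes such samples.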

I got the following working reasonably well, but it relies on picking the maximum value at the end (argmax), which I would like to avoid:

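  from tensorflow import keras
  from tensorflow.keras import regularizers
  from tensorflow.keras.layers import Activation, Dense
  from tensorflow.keras.models import Sequential

  model = Sequential()
  # hidden layer: 30 input features -> 30 ReLU units
  model.add(Dense(30, input_shape=(30,), activation='relu', kernel_initializer='he_uniform'))
  # output layer: one linear score per class, L2-regularized (similar in spirit to an SVM's weight penalty)
  model.add(Dense(3, kernel_regularizer=regularizers.l2(0.1)))
  # the activation is linear by default, which works; with softmax, accuracy gets stuck at 33% for 3 classes (25% for 4)
  #model.add(Activation('softmax'))
  model.compile(loss='categorical_hinge', optimizer=keras.optimizers.Adam(learning_rate=1e-3), metrics=['accuracy'])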
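For context on why the linear output layer behaves like an SVM here: `categorical_hinge` compares the score of the true class against the highest score among the other classes, with a margin of 1. Below is a NumPy sketch that mirrors the formula used in the Keras source as I understand it; it is an illustration, not the Keras implementation itself:

```python
import numpy as np

def categorical_hinge(y_true, y_pred):
    """NumPy mirror of the Keras categorical_hinge formula:
    max(0, neg - pos + 1) per sample."""
    y_true = np.asarray(y_true, dtype=float)  # one-hot targets
    y_pred = np.asarray(y_pred, dtype=float)  # raw (linear) class scores
    pos = np.sum(y_true * y_pred, axis=-1)          # score of the true class
    neg = np.max((1.0 - y_true) * y_pred, axis=-1)  # best score among the other classes
    return np.maximum(0.0, neg - pos + 1.0)

# true class wins by more than the margin -> zero loss
print(categorical_hinge([[1, 0, 0]], [[2.0, 0.5, -1.0]]))
```

Because `(1 - y_true) * y_pred` zeroes the true-class entry instead of excluding it, `neg` is never below 0, so this differs slightly from a textbook Crammer-Singer hinge when every wrong-class score is negative. Nothing in this loss pushes the scores into [0, 1] or makes them sum to 1, which is why the raw outputs of this model are margins rather than probabilities.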

Any ideas on how to tackle this untrained-class problem? Something like Platt scaling or temperature scaling would work, as long as I can still save the model as ONNX.
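Temperature scaling fits a single scalar T > 0 on held-out data so that softmax(logits / T) minimizes the negative log-likelihood; dividing by a positive T never changes the argmax, only how confident the probabilities are. A minimal NumPy sketch, assuming you have validation logits (the raw outputs of the linear layer) and integer labels; the grid search and the names `nll` and `fit_temperature` are my own illustrative choices, not a library API:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # shift by the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def nll(logits, labels, T):
    """Mean negative log-likelihood of the true labels at temperature T."""
    p = softmax(logits / T)
    return -np.mean(np.log(p[np.arange(len(labels)), labels] + 1e-12))

def fit_temperature(logits, labels, grid=np.linspace(0.05, 10.0, 200)):
    """Pick the temperature on the grid with the lowest validation NLL."""
    losses = [nll(logits, labels, T) for T in grid]
    return float(grid[int(np.argmin(losses))])

# toy validation set: the third sample is confidently wrong
val_logits = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [10.0, 0.0, 0.0]])
val_labels = np.array([0, 1, 1])
T = fit_temperature(val_logits, val_labels)
calibrated = softmax(val_logits / T)
```

Platt scaling, by contrast, fits a sigmoid (a logistic regression) on the SVM scores and is usually applied one-vs-rest for multiclass problems; temperature scaling is shown here because it needs only one parameter. Since the fitted T is a constant, one untested option for ONNX is to bake it into the network (divide the logits by T, then apply softmax), so the exported graph contains only standard operations.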

Apidcloud

1 Answer


As I suspected, I got softmax to work by scaling the features (the input) of the model. There was no need for stop gradient or anything else. My features were really big numbers, which, even though the linear-output model trained well, were preventing softmax (logistic regression) from working properly. The features can be scaled, for instance, with the following code:

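from sklearn.preprocessing import StandardScaler
# X: training feature matrix, shape (n_samples, n_features)
scaler = StandardScaler()
# rescale each feature to zero mean and unit variance
X_std = scaler.fit_transform(X)
# at inference time, reuse the fitted scaler: scaler.transform(new_samples)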

With the features scaled, the SVM-like Keras model outputs probabilities as originally intended.
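A small NumPy illustration of the mechanism, not a reconstruction of the original model: very large inputs tend to produce very large pre-softmax scores, the softmax then saturates to an almost one-hot vector, and its Jacobian p_i(δ_ij - p_j) becomes nearly zero, so little gradient flows back. The score values below are made up, and `standardize` is a hypothetical hand-written equivalent of `StandardScaler.fit_transform`:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_jacobian(p):
    """d softmax_i / d z_j = p_i * (delta_ij - p_j)."""
    return np.diag(p) - np.outer(p, p)

def standardize(X):
    """Zero mean, unit variance per column (population std, like StandardScaler)."""
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std[std == 0] = 1.0  # leave constant columns unscaled
    return (X - mean) / std

big = softmax(np.array([3000.0, 2990.0, 2950.0]))  # scores from huge raw features
small = softmax(np.array([3.0, 2.99, 2.95]))       # similar scores on a sane scale

print(np.abs(softmax_jacobian(big)).max())    # tiny: gradients barely flow
print(np.abs(softmax_jacobian(small)).max())  # much larger: the model can learn
```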

Apidcloud