I have a code snippet that implements an SVM from scratch (a binary classifier). This code won't work on MNIST because its output predictions are only 0 or 1.


import numpy as np


class SVM:
    def __init__(self, learning_rate=1e-3, lambda_param=1e-2, n_iters=1000):
        self.lr = learning_rate
        self.lambda_param = lambda_param
        self.n_iters = n_iters
        self.w = None
        self.b = None

    def _init_weights_bias(self, X):
        n_features = X.shape[1]
        self.w = np.zeros(n_features)
        self.b = 0

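    def _get_cls_map(self, y):
        # Map labels to -1/+1, as the hinge loss expects.
        return np.where(y <= 0, -1, 1)

    def _satisfy_constraint(self, x, idx):
        # True when the sample lies on the correct side, outside the margin.
        linear_model = np.dot(x, self.w) + self.b
        return self.cls_map[idx] * linear_model >= 1

    def _get_gradients(self, constrain, x, idx):
        if constrain:
            # Only the regularization term contributes.
            dw = self.lambda_param * self.w
            db = 0
            return dw, db

        # Margin violated: add the hinge-loss gradient.
        dw = self.lambda_param * self.w - self.cls_map[idx] * x
        db = - self.cls_map[idx]
        return dw, db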
    
    def _update_weights_bias(self, dw, db):
        self.w -= self.lr * dw
        self.b -= self.lr * db
    
    def fit(self, X, y):
        self._init_weights_bias(X)
        self.cls_map = self._get_cls_map(y)

        for _ in range(self.n_iters):
            for idx, x in enumerate(X):
                constrain = self._satisfy_constraint(x, idx)
                dw, db = self._get_gradients(constrain, x, idx)
                self._update_weights_bias(dw, db)
    
    def predict(self, X):
        estimate = np.dot(X, self.w) + self.b
        prediction = np.sign(estimate)
        return np.where(prediction == -1, 0, 1)

What can I do to classify more than 20 classes? In other words, how can I make it a multiclass SVM? Any hints, please?

Not Bedo

1 Answer

SVMs do not handle the multiclass case natively. Most implementations just fit as many binary classifiers as there are classes (one-vs-rest) or as many as there are possible pairs of classes (one-vs-one). For 20 classes, that means 20 classifiers for one-vs-rest or 190 for one-vs-one.
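A minimal one-vs-rest sketch, to make the idea concrete. The names `TinyLinearSVM` and `OneVsRestSVM` are hypothetical, and `TinyLinearSVM` is only a compact stand-in for the binary classifier in the question. The one thing it adds is a `decision_function` that returns the raw score instead of a 0/1 label, because the wrapper needs real-valued scores to pick the most confident class.

```python
import numpy as np


class TinyLinearSVM:
    """Hypothetical compact stand-in for a binary hinge-loss SVM."""

    def __init__(self, lr=1e-3, lambda_param=1e-2, n_iters=200):
        self.lr = lr
        self.lambda_param = lambda_param
        self.n_iters = n_iters

    def fit(self, X, y):
        # y must already hold -1/+1 labels.
        self.w = np.zeros(X.shape[1])
        self.b = 0.0
        for _ in range(self.n_iters):
            for x_i, y_i in zip(X, y):
                dw = self.lambda_param * self.w
                if y_i * (x_i @ self.w + self.b) < 1:
                    # Margin violated: add the hinge-loss term.
                    dw = dw - y_i * x_i
                    self.b += self.lr * y_i
                self.w -= self.lr * dw
        return self

    def decision_function(self, X):
        # Signed score; larger means more confidently the positive class.
        return X @ self.w + self.b


class OneVsRestSVM:
    """Hypothetical wrapper: one binary model per class, argmax over scores."""

    def __init__(self, make_binary=TinyLinearSVM):
        self.make_binary = make_binary

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        # For each class c, train "c versus everything else".
        self.models_ = [
            self.make_binary().fit(X, np.where(y == c, 1, -1))
            for c in self.classes_
        ]
        return self

    def predict(self, X):
        # One column of scores per class; pick the highest-scoring class.
        scores = np.column_stack([m.decision_function(X) for m in self.models_])
        return self.classes_[np.argmax(scores, axis=1)]
```

For MNIST, `X` would be the flattened images and `y` the digit labels; `OneVsRestSVM().fit(X, y).predict(X_new)` then returns the original labels rather than 0 or 1.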

A more elegant approach would be error-correcting output codes, which borrow the idea behind Hamming codes (see e.g. https://github.com/christianversloot/machine-learning-articles/blob/main/using-error-correcting-output-codes-for-multiclass-svm-classification.md), though it sounds like quite a challenge to implement from scratch.
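The encoding and decoding part of that idea is small, so here is a rough sketch of it, assuming 4 classes indexed 0 to 3 and a hand-picked 7-bit code matrix. `CODE`, `bit_targets`, and `decode` are hypothetical names. Training one binary classifier per column is left out; `decode` only shows how their 0/1 outputs would be turned back into a class.

```python
import numpy as np

# Exhaustive 7-bit code for 4 classes (one row per class). Any two rows
# differ in 4 bits, so one wrong bit still decodes to the right class.
CODE = np.array([
    [1, 1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 1, 1, 1],
    [0, 0, 1, 1, 0, 0, 1],
    [0, 1, 0, 1, 0, 1, 0],
])


def bit_targets(y, code=CODE):
    """Binary training targets: column j holds the 0/1 labels for classifier j."""
    return code[y]


def decode(bits, code=CODE):
    """For each row of predicted bits, return the class with the nearest codeword."""
    # Hamming distance from every prediction row to every codeword.
    dists = (bits[:, None, :] != code[None, :, :]).sum(axis=2)
    return np.argmin(dists, axis=1)
```

With a binary learner that expects -1/+1 labels, the 0/1 targets from `bit_targets` would need the same kind of mapping that `_get_cls_map` applies in the question.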

dx2-66