I'm trying to train an LGBMClassifier for a multiclass task. I first tried working directly with the LightGBM API and set up the model and training as follows:
LightGBM API
import lightgbm as lgb
import numpy as np
train_data = lgb.Dataset(X_train, (y_train-1))
test_data = lgb.Dataset(X_test, (y_test-1))
params = {}
params['learning_rate'] = 0.3
params['boosting_type'] = 'gbdt'
params['objective'] = 'multiclass'
params['metric'] = 'softmax'
params['max_depth'] = 10
params['num_class'] = 8
params['num_leaves'] = 500
lgb_train = lgb.train(params, train_data, 200)
# AFTER TRAINING THE MODEL
y_pred = lgb_train.predict(X_test)
y_pred_class = [np.argmax(line) for line in y_pred]
y_pred_class = np.asarray(y_pred_class) + 1
This is how the confusion matrix looks:
Sklearn API
Then I tried to move to the Sklearn API so that I could use other tools. This is the code I used:
from lightgbm import LGBMClassifier
lgb_clf = LGBMClassifier(objective='multiclass',
boosting_type='gbdt',
max_depth=10,
num_leaves=500,
learning_rate=0.3,
eval_metric=['accuracy','softmax'],
num_class=8,
n_jobs=-1,
early_stopping_rounds=100,
num_iterations=500)
clf_train = lgb_clf.fit(X_train, (y_train-1), verbose=1, eval_set=[(X_train, (y_train-1)), (X_test, (y_test-1))])
# TRAINING: I can see overfitting is happening
y_pred = clf_train.predict(X_test)
y_pred = [np.argmax(line) for line in y_pred]
y_pred = np.asarray(y_pred) + 1
And this is the confusion matrix in this case:
Notes
- I need to subtract 1 from y_train because my classes start at 1 and LightGBM was complaining about this.
- When I try a RandomSearch or a GridSearch I always obtain the same result as the last confusion matrix.
- I have checked different questions here, but none of them solves this issue.
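As a side note on the label shift in the first note: subtracting 1 only works when the classes are exactly 1..N with no gaps. A minimal sketch of a more general mapping, using only NumPy, is below. The label values and the `pred_idx` array are made up for illustration and are not from my data.

```python
import numpy as np

# Hypothetical labels that start at 1 and have gaps (illustration only).
y_train = np.array([1, 3, 8, 3, 1, 5])

# classes: sorted unique labels; y_encoded: index of each label in classes,
# i.e. contiguous integers in [0, len(classes)) as the native API expects.
classes, y_encoded = np.unique(y_train, return_inverse=True)

# Hypothetical predicted class indices (e.g. the argmax of predicted probabilities).
pred_idx = np.array([0, 2, 1])

# Map the indices back to the original labels.
pred_labels = classes[pred_idx]
```

For labels that are exactly 1..N this reduces to the same `y - 1` / `+ 1` shift used above.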
Questions
- Is there anything I'm missing when implementing the model with the Sklearn API?
- Why do I obtain good results (maybe with overfitting) with LightGBM API?
- How can I achieve the same results with the two APIs?
Thanks in advance.
UPDATE It was my mistake. I assumed the output of both APIs would be the same, but it isn't. I just removed the np.argmax() line when predicting with the Sklearn API; it seems this API already predicts the class directly. Please don't remove the question, in case someone else is dealing with a similar issue.
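To illustrate the difference, here is a small sketch with made-up arrays standing in for the two outputs (no model is trained; the numbers are hypothetical). As far as I understand, `lgb.train(...).predict` returns one row of class probabilities per sample for a multiclass objective, while `LGBMClassifier.predict` returns the class labels themselves (`predict_proba` gives the probabilities).

```python
import numpy as np

# Hypothetical probabilities for 3 samples and 4 classes, shape (n_samples, num_class),
# standing in for the output of the native Booster.predict.
proba = np.array([
    [0.1, 0.7, 0.1, 0.1],
    [0.6, 0.2, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.7],
])

# Native API path: argmax over the class axis, then shift back to 1-based labels.
native_labels = proba.argmax(axis=1) + 1

# Hypothetical 0-based class labels, shape (n_samples,),
# standing in for the output of LGBMClassifier.predict.
sk_pred = np.array([1, 0, 3])

# Correct Sklearn path: only undo the label shift.
sk_labels = sk_pred + 1

# My mistake: np.argmax of a single scalar is always 0,
# so every sample ends up as class 1 after the shift.
wrong_labels = np.asarray([np.argmax(v) for v in sk_pred]) + 1
```

Here `native_labels` and `sk_labels` agree, while `wrong_labels` collapses to a single class, which would explain a confusion matrix with all predictions in one column.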