
I'm trying to create a classification model to predict one of two classes: "Hit" or "Miss".

The dataset contains around 80% "Hits", so it is highly imbalanced; models such as classification trees (ctree from the party package) end up predicting every outcome as "Hit" and still reach 80% accuracy.

I tried undersampling and the SMOTE algorithm, without success.

How can I change the cost matrix in order to penalize the model when it classifies a "Miss" as a "Hit"?

They choose to predict all outcomes as class "Hit" because they typically use a 0.5 predicted-probability threshold. Try working with the predicted probability rather than the predicted class and apply your own threshold(s): see how the false positives and false negatives change when you classify as "Hit" anything with predicted probability 0.7 or above, 0.8 or above, 0.9 or above, etc. – AntoniosK Jan 22 '18 at 19:37
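
A minimal sketch of that thresholding idea, assuming a party::ctree model already fitted as ct (the object name, the threshold value, and the "Hit"/"Miss" factor levels are illustrative assumptions, not part of the original post):

library(party)

## treeresponse() returns one vector of class probabilities per observation
## (assumes ct is a fitted BinaryTree and "Hit" is the first factor level)
p_hit = sapply(treeresponse(ct), function(p) p[1])

## Predict "Hit" only when its probability clears a custom threshold,
## e.g. 0.8 instead of the implicit 0.5 used by predict()
pred = factor(ifelse(p_hit >= 0.8, "Hit", "Miss"), levels = c("Hit", "Miss"))

## Compare against the true labels to see how the error trade-off shifts
## table(pred, truth)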

1 Answer


You can do that with the weights argument to ctree. Since you do not provide any data, I will illustrate with bogus data.

library(party)

## Some bogus data: roughly 80% class 1 and 20% class 2,
## with class 2 shifted slightly in both x and y
set.seed(42)
class = factor(sample(1:2, 500, replace=TRUE, prob=c(0.8, 0.2)))
x1 = rnorm(500)
x2 = rnorm(500, 0.7, 0.9)
x = ifelse(class == 1, x1, x2)
y1 = rnorm(500)
y2 = rnorm(500, 0.7, 0.9)
y = ifelse(class == 1, y1, y2)
Imbalanced = data.frame(x, y, class)

Just using ctree on this data makes it classify all data as class 1.

CT1 = ctree(class ~ ., data=Imbalanced)
table(predict(CT1))
  1   2 
500   0 

But if you set the weights, you can make it find more of the class 2 data.

W = ifelse(class==1, 1, 2)
CT2 = ctree(class ~ ., data=Imbalanced, weights=W)
table(predict(CT2), class)
   class
      1   2
  1 336  44
  2  63  57

Notice that the overall accuracy has gone down, but more of the class 2 points are now correctly classified. If you use a much bigger weighting factor, you can recover almost all of the class 2 points (at the expense of an even greater loss of overall accuracy).

W = ifelse(class==1, 1, 5)
CT3 = ctree(class ~ ., data=Imbalanced, weights=W)
table(predict(CT3), class)
   class
      1   2
  1 178   4
  2 221  97
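
If you prefer to think in the cost-matrix terms of the question, one simple sketch (an assumption about how to encode costs, not a built-in feature of party) is to weight each observation by the cost of misclassifying its true class. With the bogus data above, a cost of 5 for calling a true class 2 a class 1 reproduces the W used for CT3:

## Hypothetical cost matrix: rows = true class, columns = predicted class.
## costs[2, 1] = 5 means predicting class 1 for a true class 2 costs 5.
costs = matrix(c(0, 1,
                 5, 0),
               nrow = 2, byrow = TRUE)

## Weight each case by the cost of misclassifying its true class;
## this gives the same W as ifelse(class==1, 1, 5) above
W = ifelse(class == 1, costs[1, 2], costs[2, 1])
CT4 = ctree(class ~ ., data = Imbalanced, weights = W)
table(predict(CT4), class)

Note that ctree expects non-negative integer case weights, so costs chosen this way should be whole numbers.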