
I have data that includes X (the features) and Y, a binary class (0 or 1). My problem is imbalanced, so I want to make sure that my y_test after the split contains about 50% samples labelled 1.

I tried using train_test_split with stratify, but it doesn't work because my ratio of class 1 is well below 50%.

Any suggestions?
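Stratification cannot produce a 50/50 test set by design: it keeps each class at the same proportion it has in the full data. The minimal NumPy sketch below illustrates that idea with a hypothetical stratified_split helper; it is not scikit-learn's actual implementation of train_test_split(..., stratify=y).

```python
import numpy as np

def stratified_split(y, test_size=0.2, seed=0):
    """Return (train_idx, test_idx) with each class kept at its original proportion."""
    rng = np.random.default_rng(seed)
    test_parts = []
    for cls in np.unique(y):
        cls_idx = np.flatnonzero(y == cls)
        # Take the same fraction of every class for the test set
        n_test = int(round(test_size * len(cls_idx)))
        test_parts.append(rng.choice(cls_idx, n_test, replace=False))
    test_idx = np.concatenate(test_parts)
    # Everything not selected for testing goes to training
    train_idx = np.setdiff1d(np.arange(len(y)), test_idx)
    return train_idx, test_idx

# 900 samples of class 0 and 100 of class 1 (10% positives)
y = np.array([0] * 900 + [1] * 100)
train_idx, test_idx = stratified_split(y)
print(np.bincount(y[test_idx]))  # 180 of class 0 and 20 of class 1: still 10% positives
```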

desertnaut
brian rik

1 Answer


You shouldn't change your train-test split because of class imbalance. The test split has to correspond to the distribution you will actually face at test time: if your problem is imbalanced, so should your test set be!
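One way to see why the test set should keep the real class ratio: metrics such as precision depend on how common the positive class is. The worked example below assumes a hypothetical classifier with a fixed true positive rate (recall) of 0.8 and a false positive rate of 0.1, and computes the precision you would expect to measure on a test set with 50% versus 10% positives.

```python
def precision_at_prevalence(tpr, fpr, prevalence):
    """Expected precision of a classifier with the given TPR/FPR at a given positive rate."""
    true_pos = tpr * prevalence          # fraction of all samples that are true positives
    false_pos = fpr * (1 - prevalence)   # fraction of all samples that are false positives
    return true_pos / (true_pos + false_pos)

for p in (0.5, 0.1):
    print(f"positives={p:.0%}: precision={precision_at_prevalence(0.8, 0.1, p):.2f}")
# positives=50%: precision=0.89
# positives=10%: precision=0.47
```

The same model would look almost twice as precise on a balanced test set as it really is on data with 10% positives.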

What you can change, though, is the metric you use and/or the training regime, e.g.:

Both of these technically have the same effect of treating the classes as equally important, but you do not have to "split things" differently.
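As an illustration (an assumption, not necessarily the examples intended here), two typical choices are a class-balanced metric such as balanced accuracy, the mean of per-class recall (scikit-learn offers balanced_accuracy_score), and class-weighted training (e.g. class_weight='balanced' on scikit-learn estimators that support it). The NumPy sketch below computes both by hand on unchanged, imbalanced data; the helper names are hypothetical.

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall, so every class counts equally."""
    classes = np.unique(y_true)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in classes]
    return float(np.mean(recalls))

def balanced_class_weights(y):
    """Per-class weights n_samples / (n_classes * count_c), the 'balanced' heuristic."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

y_true = np.array([0] * 90 + [1] * 10)
y_pred = np.zeros(100, dtype=int)  # a "classifier" that always predicts 0

print(np.mean(y_pred == y_true))          # plain accuracy: 0.9
print(balanced_accuracy(y_true, y_pred))  # balanced accuracy: 0.5
print(balanced_class_weights(y_true))     # class 1 weighted 9x more than class 0
```

The weights would then be used as per-sample weights or in a weighted loss during training, so the split itself stays untouched.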

And if you really insist on splitting the data in such an odd way, just do it by hand:

import numpy as np

def odd_split(X, y, minority_class=1, minority_test_size=0.1, seed=None):
    """Split so that the test set holds equal numbers of minority and majority samples.

    minority_test_size is the fraction of the minority class moved to the test set;
    the same number of majority samples is then drawn to match it.
    """
    rng = np.random.default_rng(seed)
    minority_indices = np.flatnonzero(y == minority_class)
    majority_indices = np.flatnonzero(y != minority_class)

    # Number of minority samples for the test set (at least one)
    n = max(1, int(minority_test_size * len(minority_indices)))
    test_minority_indices = rng.choice(minority_indices, n, replace=False)
    # Draw the same number of majority samples, giving a 50/50 test set
    test_majority_indices = rng.choice(majority_indices, n, replace=False)

    test_indices = np.concatenate((test_minority_indices, test_majority_indices))
    # Every index not used for testing goes to the training set
    train_indices = np.setdiff1d(np.arange(len(y)), test_indices)

    return X[train_indices], y[train_indices], X[test_indices], y[test_indices]
  

from collections import Counter  
X = np.random.normal(size=(1000, 2))  
y = np.random.choice([0, 1], p=[0.9, 0.1], size=1000)
print('Whole', Counter(y))

X_train, y_train, X_test, y_test = odd_split(X, y)
print('Train', Counter(y_train))
print('Test', Counter(y_test))

Which gives (the exact counts vary between runs, since the data are random):

Whole Counter({0: 886, 1: 114})
Train Counter({0: 875, 1: 103})
Test Counter({1: 11, 0: 11})
lejlot
Thanks! I'm interested in making the split and then using SMOTE to balance my training data, but the split puts a very small number of class 1 instances in the test set, so I thought I could "buff" it before the split. – brian rik Jun 14 '22 at 05:09
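The workflow in this comment (split first, then oversample only the training set) keeps synthetic samples out of the test set. As an illustration, here is a simplified NumPy sketch of SMOTE's core idea: each synthetic point is interpolated between a random minority sample and one of its k nearest minority neighbours. smote_like is a hypothetical helper, not the imbalanced-learn implementation; in practice you would use that library's SMOTE class.

```python
import numpy as np

def smote_like(X_min, n_new, k=5, seed=0):
    """Create n_new synthetic minority samples by SMOTE-style interpolation.

    Requires at least two minority samples.
    """
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between minority samples
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # a sample is not its own neighbour
    k = min(k, len(X_min) - 1)
    neighbours = np.argsort(dists, axis=1)[:, :k]  # k nearest per sample

    base = rng.integers(0, len(X_min), size=n_new)          # random minority samples
    nbr = neighbours[base, rng.integers(0, k, size=n_new)]  # one random neighbour each
    gap = rng.random((n_new, 1))                            # position along the segment
    return X_min[base] + gap * (X_min[nbr] - X_min[base])

# Oversample class 1 in the TRAINING set only; the test set is left as it is
rng = np.random.default_rng(1)
X_train = rng.normal(size=(200, 2))
y_train = np.array([0] * 180 + [1] * 20)

X_min = X_train[y_train == 1]
n_new = int((y_train == 0).sum() - (y_train == 1).sum())  # match the majority count
X_syn = smote_like(X_min, n_new)
X_bal = np.vstack([X_train, X_syn])
y_bal = np.concatenate([y_train, np.ones(n_new, dtype=int)])
print(np.bincount(y_bal))  # both classes now have 180 samples
```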