What is the simplest way to apply class_weight
in Keras when the dataset is unbalanced?
I only have two classes in my target.
Thanks.
The class_weight parameter of the fit() function is a dictionary mapping classes to a weight value.
Let's say you have 500 samples of class 0 and 1500 samples of class 1; then you pass class_weight = {0: 3, 1: 1}. That gives class 0 three times the weight of class 1.
If you are using a generator, train_generator.classes gives you the class index of each sample, which you can use to work out your weighting.
If you want to calculate this programmatically, you can use scikit-learn's sklearn.utils.class_weight.compute_class_weight().
The function looks at the distribution of labels and produces weights that equally penalize under- or over-represented classes in the training set.
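To make the 'balanced' mode concrete, here is a minimal pure-Python sketch of the heuristic scikit-learn documents for it, n_samples / (n_classes * class_count). The function name balanced_weights is hypothetical and the labels are made up to match the 500/1500 example above; treat it as an illustration, not scikit-learn's actual implementation.

```python
from collections import Counter


def balanced_weights(labels):
    """Return {class: n_samples / (n_classes * class_count)}."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * cnt) for cls, cnt in sorted(counts.items())}


# Made-up labels: 500 samples of class 0, 1500 samples of class 1
labels = [0] * 500 + [1] * 1500
weights = balanced_weights(labels)
print(weights)
```

The ratio between the two weights is still 3:1, the same as the hand-written {0: 3, 1: 1}; only the overall scale differs.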
See also this useful thread here: https://github.com/fchollet/keras/issues/1875
And this thread might also be of help: Is it possible to automatically infer the class_weight from flow_from_directory in Keras?
I'm also using class_weight from scikit-learn to deal with imbalanced data:
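import numpy as np
from sklearn.utils.class_weight import compute_class_weight
# 'balanced' weights each class inversely to its frequency in Y_train
classes = np.unique(Y_train)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=Y_train)
# Keras expects a dict mapping class index to weight, not an array
class_weight = dict(zip(classes, weights))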
Then pass it to fit:
Classifier.fit(train_X, train_Y, batch_size=100, epochs=10,
               validation_data=(test_X, test_Y), class_weight=class_weight)
1- Define a dictionary with your labels and their associated weights:
class_weight = {0: 0.1,
                1: 1.,
                2: 2.}
2- Feed the dictionary as a parameter:
model.fit(X_train, Y_train, batch_size = 100, epochs = 10, class_weight=class_weight)
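As a rough illustration of what the dictionary does during training: as far as I understand it, Keras turns class_weight into a per-sample weight that scales each sample's contribution to the loss. The sketch below mimics that idea with plain Python and binary cross-entropy. weighted_bce is a hypothetical helper and the labels and predictions are invented, so read it as a conceptual sketch rather than Keras's internal code.

```python
import math


def weighted_bce(y_true, y_pred, class_weight):
    """Mean binary cross-entropy with each sample scaled by its class weight."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        loss = -(y * math.log(p) + (1 - y) * math.log(1 - p))
        total += class_weight[y] * loss
    return total / len(y_true)


# Invented data: one class-0 sample, three class-1 samples,
# all predicted with the same amount of error
y_true = [0, 1, 1, 1]
y_pred = [0.4, 0.6, 0.6, 0.6]

plain = weighted_bce(y_true, y_pred, {0: 1.0, 1: 1.0})
weighted = weighted_bce(y_true, y_pred, {0: 3.0, 1: 1.0})
print(plain, weighted)
```

With the 3:1 weights, the single class-0 sample contributes as much to the total loss as the three class-1 samples combined, which is the point of weighting the minority class.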
Are you asking about the right weighting to apply or how to do that in the code? The code is simple:
class_weights = {}
for i in range(2):
    class_weights[i] = your_weight  # replace with the weight you choose for class i
and then you pass the argument class_weight=class_weights in model.fit.
The right weighting to use would be some sort of inverse frequency; you can also do a bit of trial and error.
class_weight takes a dictionary. With a generator, you can build it from the class counts:
from collections import Counter
itemCt = Counter(trainGen.classes)
maxCt = float(max(itemCt.values()))
cw = {clsID : maxCt/numImg for clsID, numImg in itemCt.items()}
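A quick way to sanity-check weights like these is to confirm that count times weight comes out the same for every class. The sketch below uses a made-up label list as a stand-in for trainGen.classes (an assumption, since no generator is available here); the variable names are illustrative.

```python
from collections import Counter

# Stand-in for trainGen.classes: 500 samples of class 0, 1500 of class 1
classes = [0] * 500 + [1] * 1500

item_ct = Counter(classes)
max_ct = float(max(item_ct.values()))
# Weight each class by (largest class count) / (its own count)
cw = {cls_id: max_ct / num_img for cls_id, num_img in item_ct.items()}

# Each class's total weight (count * weight) should now be equal
totals = {cls_id: item_ct[cls_id] * w for cls_id, w in cw.items()}
print(cw)      # {0: 3.0, 1: 1.0}
print(totals)  # {0: 1500.0, 1: 1500.0}
```

The resulting dictionary can be passed as class_weight=cw in model.fit.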