
I am currently trying to train an image classification model using PyTorch densenet121 with 4 labels (A, B, C, D). I have 224,000 images, and each image is labeled as a multi-hot vector, e.g. [1, 0, 0, 1] (labels A and D are present in the image). I have replaced the last dense layer of densenet121. The model is trained using the Adam optimizer with an LR of 0.0001 (decayed by a factor of 10 per epoch) for 4 epochs. I will try more epochs once I am confident that the class imbalance issue is resolved.
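For reference, my setup looks roughly like this (a minimal sketch; `NUM_CLASSES` is just an illustrative name, and the exact scheduler call is my paraphrase of the decay described above):

import torch
import torch.nn as nn
import torchvision.models as models

NUM_CLASSES = 4  # labels A, B, C, D

# Replace the last dense (classifier) layer of densenet121 with a 4-output head
model = models.densenet121(pretrained=True)
model.classifier = nn.Linear(model.classifier.in_features, NUM_CLASSES)

# Adam with LR 0.0001, decayed by a factor of 10 every epoch
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)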

The estimated number of positive samples per class is [19000, 65000, 38000, 105000] respectively. When I trained the model without any class balancing or weights (with BCELoss), I got very low recall for labels A and C (in fact, the true positive (TP) and false positive (FP) counts were each less than 20).
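(For context, the positive counts above come straight from the multi-hot label columns, using the same `train_df` and `CLASSES` as in the snippets below:)

# Number of positive samples per class across the whole training set
pos_counts = train_df[CLASSES].sum(axis=0)
print(pos_counts)  # roughly [19000, 65000, 38000, 105000]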

I have tried 3 approaches to deal with the class imbalance, after an extensive search on Google and Stack Overflow.

Approach 1: Class weights

I have tried to implement class weights by using the ratio of negative samples to positive samples.

import torch
import torch.nn as nn

# Ratio of negative to positive samples for each of the 4 label columns
y = train_df[CLASSES]
pos_weight = (y == 0).sum() / (y == 1).sum()

# Convert the pandas Series to a float tensor
pos_weight = torch.tensor(pos_weight.values, dtype=torch.float)
if torch.cuda.is_available():
    pos_weight = pos_weight.cuda()
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

The resulting class weights are [10.79, 2.45, 4.90, 1.13]. I am getting the opposite effect: too many positive predictions, which results in low precision.

Approach 2: Changing logic for class weights

I have also tried computing class weights as the inverse of each class's proportion of positive samples in the dataset (see the snippet below). The resulting class weights are [11.95, 3.49, 5.97, 2.16]. I am still getting too many positive predictions.

import pandas as pd

# Proportion of positive samples per class, then take the inverse as the weight
class_dist = y.apply(pd.Series.value_counts)
class_dist_norm = class_dist.loc[1.0] / class_dist.loc[1.0].sum()
pos_weight = 1 / class_dist_norm

Approach 3: Focal Loss

I have also tried Focal Loss with the following implementation (but am still getting too many positive predictions). I used the class weights for the alpha parameter. The code is adapted from https://gist.github.com/f1recracker/0f564fd48f15a58f4b92b3eb3879149b, with some modifications to better suit my use case.

class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets.

    Expects input_ to be probabilities in [0, 1] (e.g. sigmoid outputs),
    not raw logits. '''

    def __init__(self, alpha=None, gamma=1.5, ignore_index=-100, reduction='mean', epsilon=1e-6):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction=reduction)
        self.reduction = reduction
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, input_, target):
        # Temporarily mask the ignore index to 0 so p_t stays a valid
        # probability; these positions contribute zero to the final loss anyway.
        target = target * (target != self.ignore_index).long()

        # p_t = p if target = 1, p_t = (1 - p) if target = 0,
        # where p is the predicted probability that target = 1
        p_t = input_ * target + (1 - input_) * (1 - target)

        # Loss = -alpha * (1 - p_t)^gamma * log(p_t), where -log(p_t) is the
        # cross entropy. Epsilon prevents log(0) when a class probability is 0.
        if self.alpha is not None:
            loss = -1 * self.alpha * torch.pow(1 - p_t, self.gamma) * torch.log(p_t + self.epsilon)
        else:
            loss = -1 * torch.pow(1 - p_t, self.gamma) * torch.log(p_t + self.epsilon)

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            return loss
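For completeness, since the p_t formula above assumes `input_` is a probability in [0, 1], the model's raw logits need a sigmoid first (a usage sketch; `model`, `images`, and `labels` are illustrative names):

criterion = FocalLoss(alpha=pos_weight, gamma=1.5)

logits = model(images)           # raw outputs, shape (batch_size, 4)
probs = torch.sigmoid(logits)    # focal loss expects probabilities
loss = criterion(probs, labels.float())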

One thing to note is that the loss was stagnant after the first epoch, but the metrics varied between epochs.

I have considered undersampling and oversampling, but I am unsure how to proceed because each image can have more than 1 label. One possible method is to oversample images with only 1 label by replicating them, but I am concerned that the model will then only generalize on single-label images and perform poorly on images with multiple labels.
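For example, one sampler-based version of oversampling I have been considering looks like this (a sketch, not a validated recipe: the per-sample weight definition is my own assumption, and `train_dataset` is an illustrative name):

import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler

# Inverse-frequency weight per class; each image is weighted by its rarest label
# (assumes every image has at least one positive label)
labels = train_df[CLASSES].values            # shape (num_images, 4), multi-hot
class_weights = 1.0 / labels.sum(axis=0)     # rarer classes get larger weights
sample_weights = (labels * class_weights).max(axis=1)

# Sample with replacement so rare-label images are drawn more often
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)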

Therefore, I would like to ask whether there are other methods I should try, or whether I made any mistakes in my approaches.

Any advice will be greatly appreciated.

Thank you!
