
I am performing multi-class text classification using BERT in Python. The dataset I am using for retraining my model is highly imbalanced. I understand that class imbalance can lead to a poor model and that one should balance the training set by undersampling, oversampling, etc. before training.

However, it is also a fact that the distribution of the training set should be similar to the distribution of the production data.

Now, if I am sure that the data thrown at me in the production environment will also be imbalanced, i.e., the samples to be classified are much more likely to belong to some classes than to others, should I balance my training set?

OR

Should I keep the training set as it is, since I know that its distribution is similar to the distribution of the data I will encounter in production?

Please give me some ideas, or point me to blogs or papers that discuss this problem.

fam
  • For the record, it's not a requirement in supervised learning to have the same distribution in the training set and the production set. What is necessary is for the *test set*, i.e. the data on which the system is evaluated, to have the same distribution as the production set. So it's ok to resample the training set, as long as the test set has the true distribution (whether it's a good idea to resample is a different matter, of course). – Erwan Jul 15 '22 at 15:20
  • Thank you! Your point regarding similar distributions for the train and test sets makes sense. However, I would like to disagree with your first point (the one regarding the distributions of the training and production sets). If what you stated is true, then why do we have to retrain models when new data becomes available? Do you think that the model is suitable for making predictions if our training and production sets have different distributions? Can you please clarify this point? If possible, I would like to have a detailed discussion on this. – fam Jul 18 '22 at 05:28
  • Sure, I'm aware that this is often a source of confusion. You're right that one would usually use the same distribution between training and production *because this is the most likely to perform well* when evaluated on a test set. But if for some reason one discovers that using a different distribution (e.g. resampled) in the training set works even better on the test set, then it's perfectly fine to use it (the whole resampling approach is based on this, btw). What matters is to keep the test set the same, because the test set represents the problem one is trying to solve. – Erwan Jul 18 '22 at 08:45
  • Changing the test set would mean *changing the problem*, possibly making it easier or harder to solve and in any case not comparable to the original problem. Changing the training set means changing the method one uses for the problem: some methods work better than others, but any method which works is ok. – Erwan Jul 18 '22 at 08:47

3 Answers


Class imbalance is not a problem by itself; the problem is that having too few minority-class samples makes it harder to describe that class's statistical distribution, which is especially true for high-dimensional data (and BERT embeddings have 768 dimensions, IIRC).

Additionally, the logistic function tends to underestimate the probability of rare events (see e.g. https://gking.harvard.edu/files/gking/files/0s.pdf for the mechanics), which can be offset by selecting a classification threshold as well as by resampling.
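For illustration, here is a minimal sketch of threshold selection (assuming a scikit-learn-style classifier with `predict_proba` on synthetic binary data; the 0.5 and 0.2 thresholds are arbitrary placeholders, not recommendations):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

# Toy imbalanced data: roughly 95% class 0, 5% class 1.
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5))
y = (rng.random(2000) < 0.05).astype(int)
X[y == 1] += 1.0  # give the rare class a slight signal

clf = LogisticRegression(max_iter=1000).fit(X, y)
proba = clf.predict_proba(X)[:, 1]

# Compare the default 0.5 threshold with a lowered one chosen for the rare class.
for threshold in (0.5, 0.2):
    preds = (proba >= threshold).astype(int)
    print(threshold, f1_score(y, preds, zero_division=0))
```

Lowering the threshold trades majority-class precision for minority-class recall without touching the training data at all.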

There are quite a few discussions on CrossValidated regarding this (like https://stats.stackexchange.com/questions/357466). TL;DR:

  • while having too few samples of a class may degrade prediction quality, resampling is not guaranteed to give an overall improvement; at the very least, there is no universal recipe for a perfect resampling proportion, so you'll have to test it out for yourself (see the sketch below);
  • however, real-life tasks often weigh classification errors unequally: resampling may help improve a particular class's metrics at the cost of overall accuracy. The same applies to classification threshold selection, however.
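A minimal sketch of what "test it out for yourself" can look like (assuming scikit-learn and synthetic data; the ratios are arbitrary placeholders): split first, resample only the training portion, and evaluate every variant on the same untouched, naturally distributed test set.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(3000, 10))
y = (rng.random(3000) < 0.05).astype(int)
X[y == 1] += 1.0

# Keep the test split untouched so it reflects the real (imbalanced) distribution.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0)

minority = np.flatnonzero(y_tr == 1)
majority = np.flatnonzero(y_tr == 0)

# Try a few minority/majority ratios and let the held-out metric decide.
for ratio in (None, 0.25, 0.5, 1.0):
    if ratio is None:
        idx = np.arange(len(y_tr))  # no resampling at all
    else:
        # replace the minority samples with a bootstrap of size ratio * |majority|
        n_needed = int(ratio * len(majority))
        extra = rng.choice(minority, size=n_needed, replace=True)
        idx = np.concatenate([majority, extra])
    clf = LogisticRegression(max_iter=1000).fit(X_tr[idx], y_tr[idx])
    print(ratio, balanced_accuracy_score(y_te, clf.predict(X_te)))
```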
dx2-66
  • Thank you for the clarification on multiple points. But my question still remains. Can you please elaborate by considering dataset balancing and distribution similarity between training and production simultaneously? Also, how is prediction based on logistic regression related to this problem? – fam Jul 18 '22 at 05:36
  • If you've got _sufficient_ minority samples, there should be no reason not to use the same distribution. If you do not, your estimation will be biased (and neural network classifiers still use the sigmoid or its softmax generalization most of the time). When resampling, it's hard to apply 'just enough' correction for that bias. That's probably alright if you're mostly interested in a minority class, but otherwise I'd personally try to balance things out using classification thresholds only: they can be adjusted trivially and do not require retraining. – dx2-66 Jul 18 '22 at 10:23

This depends on the goal of your classification:

  • Do you want a high probability that a random sample is classified correctly? -> Do not balance your training set.
  • Do you want a high probability that a random sample from a rare class is classified correctly? -> Balance your training set, or apply class weighting during training, increasing the weights for rare classes.

For example, in web applications seen by clients, it is important that most samples are classified correctly, disregarding rare classes, whereas in the case of anomaly detection/classification, it is very important that rare classes are classified correctly.

Keep in mind that a model trained on a highly imbalanced dataset tends to always predict the majority class, so increasing the number or weight of rare-class samples can be a good idea, even without perfectly balancing the training set.
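A minimal sketch of the weighting route (assuming PyTorch; the inverse-frequency "balanced" heuristic shown here is just one common choice):

```python
import numpy as np
import torch

# Toy label array for a 4-class problem with heavy imbalance.
labels = np.array([0] * 900 + [1] * 60 + [2] * 30 + [3] * 10)

# Inverse-frequency class weights: n_samples / (n_classes * count_per_class).
counts = np.bincount(labels, minlength=4)
weights = counts.sum() / (len(counts) * counts)
class_weights = torch.tensor(weights, dtype=torch.float32)

# The weighted loss up-weights errors on rare classes during training.
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(8, 4)            # stand-in for model outputs
targets = torch.randint(0, 4, (8,))   # stand-in for batch labels
loss = criterion(logits, targets)
print(loss.item())
```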

chefhose

P(label | sample) is not the same as P(label).

P(label | sample) is your training goal.

In the case of gradient-based learning with mini-batches on models with a large parameter space, rare labels leave only a small footprint on the model during training. So your model ends up fitting P(label).

To avoid fitting to P(label), you can balance the batches. Over all the batches of an epoch, the data then looks as if the minority classes had been up-sampled. The goal is to get a loss function whose gradients move the parameters toward a better classification objective.
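A minimal sketch of one way to get such balanced batches (assuming PyTorch; `WeightedRandomSampler` with replacement is one option among several):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced dataset: 950 samples of class 0, 50 of class 1.
features = torch.randn(1000, 16)
labels = torch.cat([torch.zeros(950, dtype=torch.long),
                    torch.ones(50, dtype=torch.long)])
dataset = TensorDataset(features, labels)

# Per-sample weights inversely proportional to class frequency,
# so each class is drawn with roughly equal probability in every batch.
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].float()
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(dataset),
                                replacement=True)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)
xb, yb = next(iter(loader))
print(yb.float().mean())  # roughly 0.5 instead of the raw 0.05
```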

UPDATE

I don't have any proof to show here, and it is perhaps not an accurate statement. With enough training data (relative to the complexity of the features) and enough training steps, you may not need balancing. But most language tasks are quite complex, and there is not enough data for training them. That is the situation I had in mind in the statements above.

Mehdi
  • Thank you! Can you please elaborate on "your model ends up fitting `P(label)`"? It would be great if you could provide a source as well. – fam Jul 18 '22 at 05:40