I have a dataset of ~2m observations which I need to split into training, validation and test sets in the ratio 60:20:20. A simplified excerpt of my dataset looks like this:
+---------+------------+-----------+-----------+
| note_id | subject_id | category  | note      |
+---------+------------+-----------+-----------+
|       1 |          1 | ECG       | blah ...  |
|       2 |          1 | Discharge | blah ...  |
|       3 |          1 | Nursing   | blah ...  |
|       4 |          2 | Nursing   | blah ...  |
|       5 |          2 | Nursing   | blah ...  |
|       6 |          3 | ECG       | blah ...  |
+---------+------------+-----------+-----------+
There are multiple categories, which are not evenly balanced, so I need to ensure that the training, validation and test sets all have the same proportions of categories as the original dataset. This part is fine: I can just use StratifiedShuffleSplit from the sklearn library.
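A minimal sketch of what the stratified part could look like on its own. The data here is synthetic (made-up sizes and category proportions, not my real dataset), and the 60:20:20 ratio is obtained by splitting twice: first 80:20, then 75:25 on the remainder.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Synthetic stand-in for the real data (assumed sizes and proportions).
rng = np.random.default_rng(0)
n = 1000
df = pd.DataFrame({
    "note_id": np.arange(1, n + 1),
    "subject_id": rng.integers(1, 201, size=n),
    "category": rng.choice(["ECG", "Discharge", "Nursing"], size=n, p=[0.2, 0.1, 0.7]),
})

# Step 1: hold out 20% of the rows as the test set, stratified by category.
outer = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
rest_idx, test_idx = next(outer.split(df, df["category"]))
rest, test = df.iloc[rest_idx], df.iloc[test_idx]

# Step 2: split the remaining 80% into 75:25, i.e. 60% and 20% of the total.
inner = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_idx, val_idx = next(inner.split(rest, rest["category"]))
train, val = rest.iloc[train_idx], rest.iloc[val_idx]

# Category proportions should be close to the original in every set.
for name, part in [("train", train), ("val", val), ("test", test)]:
    print(name, len(part), part["category"].value_counts(normalize=True).round(3).to_dict())

# Nothing here keeps a subject's notes together, so subjects can be shared across sets.
print("subjects shared by train and test:", len(set(train["subject_id"]) & set(test["subject_id"])))
```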
However, I also need to ensure that the observations from each subject are not split across the training, validation and test datasets. All the observations from a given subject need to be in the same bucket to ensure my trained model has never seen the subject before when it comes to validation/testing. For example, if subject_id 1 is assigned to the training set, then every observation of subject_id 1 should be in the training set.
I can't think of a way to ensure a stratified split by category, prevent contamination (for want of a better word) of subject_id across datasets, ensure a 60:20:20 split and ensure that the dataset is somehow shuffled. Any help would be appreciated!
Thanks!
EDIT:
I've now learnt that keeping groups together across dataset splits (here, all observations with the same subject_id) can also be accomplished in sklearn through the GroupShuffleSplit class. So essentially, what I need is a combined stratified and grouped shuffle split, i.e. a StratifiedGroupShuffleSplit, which does not exist. GitHub issue: https://github.com/scikit-learn/scikit-learn/issues/12076
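A minimal sketch of the grouped part on its own, again on synthetic data (made-up sizes and proportions). Note that GroupShuffleSplit applies test_size to the number of groups rather than rows, so the row ratio is only approximately 60:20:20, and nothing here balances the categories, which is exactly the missing piece.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Synthetic stand-in for the real data (assumed sizes and proportions).
rng = np.random.default_rng(0)
n = 1000
df = pd.DataFrame({
    "note_id": np.arange(1, n + 1),
    "subject_id": rng.integers(1, 201, size=n),
    "category": rng.choice(["ECG", "Discharge", "Nursing"], size=n, p=[0.2, 0.1, 0.7]),
})

# Step 1: hold out roughly 20% of the subjects (not rows) as the test set.
outer = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
rest_idx, test_idx = next(outer.split(df, groups=df["subject_id"]))
rest, test = df.iloc[rest_idx], df.iloc[test_idx]

# Step 2: split the remaining subjects 75:25 into training and validation.
inner = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_idx, val_idx = next(inner.split(rest, groups=rest["subject_id"]))
train, val = rest.iloc[train_idx], rest.iloc[val_idx]

# Each subject ends up in exactly one set; category proportions are left to chance.
for name, part in [("train", train), ("val", val), ("test", test)]:
    print(name, len(part), part["category"].value_counts(normalize=True).round(3).to_dict())
```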