I'm relatively new to sklearn and have a question about using train_test_split from sklearn.model_selection. I have a large dataframe with shape (96350, 156). It has a column named CountryName that contains 160 countries, and each country has about 600 instances.

Input:

df['CountryName'].unique()

Output:

array(['Aruba', 'Afghanistan', 'Angola', 'Albania', 'Andorra',
       'United Arab Emirates', 'Argentina', 'Australia', 'Austria',
       'Azerbaijan', 'Belgium', 'Benin', 'Burkina Faso', 'Bangladesh',
       'Bulgaria', 'Bahrain', 'Bahamas', 'Bosnia and Herzegovina',
...
       'Slovenia', 'Sweden', 'Eswatini', 'Seychelles', 'Chad', 'Togo',
       'Thailand', 'Trinidad and Tobago', 'Tunisia', 'Turkey', 'Taiwan',
       'Tanzania', 'Uganda', 'Ukraine', 'Uruguay', 'United States',
       'Uzbekistan', 'Venezuela', 'Vietnam', 'South Africa', 'Zambia',
       'Zimbabwe'], dtype=object)

How can I apply train_test_split at the level of countries rather than at the level of instances? To illustrate the question, I made a quick table that represents my dataframe. How can I perform train_test_split on a single country, for example Aruba (so that 70% of Aruba's rows become training data and 30% become test data), do the same for every country, and finally combine the resulting pieces (X_train, X_test, y_train and y_test) into single dataframes?

To visualize:

(____part of X dataset____)   (y dataset)   
CountryName  value1  value2 ... valueN
   Aruba       1       3    ...   3
   Aruba       2       4    ...   6
   Aruba       3       4    ...   1
    ...       ...     ...   ...  ...
   Sweden      5       3    ...   2
   Sweden      4       7    ...   2
    ...       ...     ...   ...  ...
  Zimbabwe     2       3    ...   9
  Zimbabwe     1       2    ...   8 
  Zimbabwe     5       1    ...   1
  Zimbabwe     5       3    ...   3
    ...       ...     ...   ...  ...
leskovecg

1 Answer

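Use the stratify parameter of train_test_split. Stratifying on CountryName keeps each country's share of rows the same in both splits, so every country is divided roughly 70/30. Something like: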

X_train, X_test = train_test_split(df, test_size=.3, stratify=df['CountryName'])

Update: with your data:

>>> train_test_split(df, test_size=.3, stratify=df['CountryName'])
[  CountryName  value1  value2  valueN
 3      Sweden       5       3       2
 7    Zimbabwe       5       1       1
 0       Aruba       1       3       3
 1       Aruba       2       4       6
 8    Zimbabwe       5       3       3
 5    Zimbabwe       2       3       9,

   CountryName  value1  value2  valueN
 6    Zimbabwe       1       2       8
 2       Aruba       3       4       1
 4      Sweden       4       7       2]
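
The question also asks for separate X_train, X_test, y_train and y_test. A minimal sketch of that, assuming the last column (valueN) is the target and using a small made-up dataframe in place of the real one (the toy data and the random_state value are illustrative, not from the original post):

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 3 countries with 10 rows each.
# "valueN" is assumed to be the target column (y).
df = pd.DataFrame({
    "CountryName": ["Aruba"] * 10 + ["Sweden"] * 10 + ["Zimbabwe"] * 10,
    "value1": range(30),
    "value2": range(30, 60),
    "valueN": range(60, 90),
})

X = df.drop(columns="valueN")  # features, still including CountryName
y = df["valueN"]               # target

# Stratifying on the country column splits every country 70/30.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=df["CountryName"], random_state=0
)

# Per-country row counts in each split.
train_counts = X_train["CountryName"].value_counts()
test_counts = X_test["CountryName"].value_counts()
print(train_counts)
print(test_counts)
```

With 10 rows per country, each country contributes 7 rows to the training set and 3 to the test set. Because the split happens in one call, there is no need to loop over countries and concatenate the pieces afterwards.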
Corralien
  • Sorry, actually you are right. Thanks a lot! Sub-question: do you know how I can regroup my countries? When I print my dataframe now, the rows are in mixed order. – leskovecg Mar 10 '22 at 14:13
  • Try `pd.concat([X_train, X_test]).sort_index()` – Corralien Mar 10 '22 at 14:16
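
The suggestion in the comment above can be sketched as follows, assuming the original dataframe was ordered by country and still has its original index. The toy data, the random_state value, and the names train and test are illustrative only:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data, ordered by country as in the question.
df = pd.DataFrame({
    "CountryName": ["Aruba"] * 4 + ["Sweden"] * 4 + ["Zimbabwe"] * 4,
    "value1": range(12),
})

train, test = train_test_split(
    df, test_size=0.25, stratify=df["CountryName"], random_state=0
)

# Option 1: keep the two splits separate, but restore the original
# row order (and therefore the country grouping) inside each split.
train_sorted = train.sort_index()
test_sorted = test.sort_index()

# Option 2: rebuild the full dataframe in its original order.
restored = pd.concat([train, test]).sort_index()
```

Sorting by index only regroups the countries if the original rows were already grouped by country. If they were not, `sort_values("CountryName")` is an alternative.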