One (somewhat) easier way to accomplish this is to add `@` as a suffix to the names of your Categorical columns, and then replace the resulting `@_` in the dummy column names using a simple list comprehension.
Let's start with this data.
import polars as pl
# Sample frame: two Categorical columns to be one-hot encoded, plus two
# integer columns that should come through untouched.
df = pl.DataFrame(
    [
        pl.Series(
            'driver_age',
            ['16_to_25', '25_to_34', '35_to_45', '45_to_55'],
            dtype=pl.Categorical,
        ),
        pl.Series(
            'marital_status',
            ['S', 'M', 'S', 'M'],
            dtype=pl.Categorical,
        ),
        pl.Series('col1', [1, 2, 3, 4]),
        pl.Series('col2', [10, 20, 30, 40]),
    ]
)
df
shape: (4, 4)
┌────────────┬────────────────┬──────┬──────┐
│ driver_age ┆ marital_status ┆ col1 ┆ col2 │
│ --- ┆ --- ┆ --- ┆ --- │
│ cat ┆ cat ┆ i64 ┆ i64 │
╞════════════╪════════════════╪══════╪══════╡
│ 16_to_25 ┆ S ┆ 1 ┆ 10 │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ 25_to_34 ┆ M ┆ 2 ┆ 20 │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ 35_to_45 ┆ S ┆ 3 ┆ 30 │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ 45_to_55 ┆ M ┆ 4 ┆ 40 │
└────────────┴────────────────┴──────┴──────┘
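We use the `suffix` expression to append `@` to the names of the Categorical columns, and then create our dummy variables. Because `get_dummies` joins each column name and category value with `_`, every dummy column name ends up containing `@_`.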
# Tag every Categorical column name with a trailing '@', then one-hot encode
# only those tagged columns; all other columns are carried over as they are.
# The right-hand side is evaluated before `df` is rebound, so both uses of
# `df` below refer to the frame as it was before this statement.
df = pl.get_dummies(
    df.select(
        [
            pl.exclude(pl.Categorical),
            pl.col(pl.Categorical).suffix('@'),
        ]
    ),
    columns=[
        name + '@'
        for name in df.select(pl.col(pl.Categorical)).columns
    ],
)
df
shape: (4, 8)
┌──────┬──────┬──────────────────────┬──────────────────────┬──────────────────────┬──────────────────────┬───────────────────┬───────────────────┐
│ col1 ┆ col2 ┆ driver_age@_16_to_25 ┆ driver_age@_25_to_34 ┆ driver_age@_35_to_45 ┆ driver_age@_45_to_55 ┆ marital_status@_M ┆ marital_status@_S │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ u8 ┆ u8 ┆ u8 ┆ u8 ┆ u8 ┆ u8 │
╞══════╪══════╪══════════════════════╪══════════════════════╪══════════════════════╪══════════════════════╪═══════════════════╪═══════════════════╡
│ 1 ┆ 10 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 20 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ 30 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4 ┆ 40 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │
└──────┴──────┴──────────────────────┴──────────────────────┴──────────────────────┴──────────────────────┴───────────────────┴───────────────────┘
From here, it's a one-liner to change the column names:
# The dummy columns are named '<column>@_<category>'; dropping the underscore
# after '@' leaves '@' as the only separator. Note that every occurrence of
# '@_' in every column name is rewritten, so this assumes '@_' does not occur
# anywhere else in a column name or category value.
df.columns = [
    name.replace('@_', '@')
    for name in df.columns
]
df
shape: (4, 8)
┌──────┬──────┬─────────────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────────────────┬──────────────────┐
│ col1 ┆ col2 ┆ driver_age@16_to_25 ┆ driver_age@25_to_34 ┆ driver_age@35_to_45 ┆ driver_age@45_to_55 ┆ marital_status@M ┆ marital_status@S │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ u8 ┆ u8 ┆ u8 ┆ u8 ┆ u8 ┆ u8 │
╞══════╪══════╪═════════════════════╪═════════════════════╪═════════════════════╪═════════════════════╪══════════════════╪══════════════════╡
│ 1 ┆ 10 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ 20 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ 30 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4 ┆ 40 ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │
└──────┴──────┴─────────────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────────────────┴──────────────────┘
The renaming is not done in Lazy mode, but then again, `get_dummies` is not available in Lazy mode either.