Let's start with some dummy data:
import polars as pl

n = 100
seed = 0
df = pl.DataFrame(
{
"groups": (pl.int_range(0, n, eager=True) % 5).shuffle(seed=seed),
"values": pl.int_range(0, n, eager=True).shuffle(seed=seed)
}
)
df
shape: (100, 2)
┌────────┬────────┐
│ groups ┆ values │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞════════╪════════╡
│ 0 ┆ 55 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0 ┆ 40 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2 ┆ 57 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 99 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ... ┆ ... │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2 ┆ 87 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1 ┆ 96 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3 ┆ 43 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 44 │
└────────┴────────┘
This gives us 5 groups of 100 / 5 = 20 elements each. Let's verify that:
df.groupby("groups").agg(pl.count())
shape: (5, 2)
┌────────┬───────┐
│ groups ┆ count │
│ --- ┆ --- │
│ i64 ┆ u32 │
╞════════╪═══════╡
│ 1 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 4 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2 ┆ 20 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 0 ┆ 20 │
└────────┴───────┘
Sample our data
Now we are going to use a window function to take a random sample of 10 rows from each group.
df.filter(
pl.int_range(0, pl.count()).shuffle().over("groups") < 10
)
shape: (50, 2)
┌────────┬────────┐
│ groups ┆ values │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞════════╪════════╡
│ 0 ┆ 85 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 0 ┆ 0 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 84 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 19 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ ... ┆ ... │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2 ┆ 87 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1 ┆ 96 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 3 ┆ 43 │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 4 ┆ 44 │
└────────┴────────┘
For every group in over("groups"), the pl.int_range(0, pl.count()) expression creates a row index that runs from 0 up to the number of rows in that group. We then shuffle that range so that we take a random sample and not a slice. Finally, we keep only the rows whose shuffled index is lower than 10. This creates a boolean mask that we can pass to the filter method.
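To make the mechanics concrete, here is a minimal pure-Python sketch of the same idea, written without Polars. It is only an illustration of the logic, not how Polars implements window functions; the helper name sample_per_group and its parameters k and seed are made up for this example.

```python
import random
from collections import defaultdict


def sample_per_group(groups, values, k, seed=None):
    """Keep up to k randomly chosen rows per group, preserving row order.

    Hypothetical helper that mimics the window expression step by step.
    """
    rng = random.Random(seed)

    # Collect the row positions that belong to each group (the "window").
    positions = defaultdict(list)
    for row, group in enumerate(groups):
        positions[group].append(row)

    # Per group: build the index 0..len-1, shuffle it, hand one index to each row.
    mask = [False] * len(groups)
    for rows in positions.values():
        index = list(range(len(rows)))
        rng.shuffle(index)
        for row, i in zip(rows, index):
            mask[row] = i < k  # the boolean mask that filter would receive

    # Apply the mask, like df.filter(mask).
    return [(g, v) for g, v, keep in zip(groups, values, mask) if keep]


groups = [i % 5 for i in range(100)]
values = list(range(100))
sample = sample_per_group(groups, values, k=10, seed=0)
print(len(sample))  # 5 groups * 10 rows each
```

Because exactly 10 of the shuffled indices in each group of 20 are below 10, every group contributes 10 rows. A group with fewer than k rows would be kept in full.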