How can I take elements by index within each group of a Polars DataFrame? For example, if I wanted to get the first and third element of each group, I might try something like this:

import polars as pl

df = pl.DataFrame(dict(x=[1,0,1,0,1,0], y=[1,2,3,4,5,6]))

df.groupby('x').take([0,2])
# AttributeError: 'GroupBy' object has no attribute 'take'

But that fails, because the GroupBy object has no take method.

drhagen

1 Answer

df.groupby("x").agg(pl.all().take([0, 2]))
shape: (2, 2)
┌─────┬────────────┐
│ x   ┆ y          │
│ --- ┆ ---        │
│ i64 ┆ list [i64] │
╞═════╪════════════╡
│ 1   ┆ [1, 5]     │
├╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 0   ┆ [2, 6]     │
└─────┴────────────┘

You can use explode to flatten the list column if needed.

df.groupby("x").agg(pl.all().take([0, 2])).explode('y')
shape: (4, 2)
┌─────┬─────┐
│ x   ┆ y   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 1   │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 1   ┆ 5   │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 0   ┆ 2   │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 0   ┆ 6   │
└─────┴─────┘

The documentation for the take expression has a similar example.