Assuming I already have a predicate expression, how do I filter with that predicate but apply it only within groups? For example, the predicate might keep all rows equal to the maximum within a group. (Multiple rows may be kept in a group if there is a tie.)
With my dplyr experience, I thought that I could just .groupby and then .filter, but that does not work.
import polars as pl
df = pl.DataFrame(dict(x=[0, 0, 1, 1], y=[1, 2, 3, 3]))
expression = pl.col("y") == pl.col("y").max()
df.groupby("x").filter(expression)
# AttributeError: 'GroupBy' object has no attribute 'filter'
I then thought I could apply .over to the expression, but that does not work either.
import polars as pl
df = pl.DataFrame(dict(x=[0, 0, 1, 1], y=[1, 2, 3, 3]))
expression = pl.col("y") == pl.col("y").max()
df.filter(expression.over("x"))
# RuntimeError: Any(ComputeError("this binary expression is not an aggregation:
# [(col(\"y\")) == (col(\"y\").max())]
# pherhaps you should add an aggregation like, '.sum()', '.min()', '.mean()', etc.
# if you really want to collect this binary expression, use `.list()`"))
For this particular problem, I can invoke .over on the max, but I don't know how to apply this to an arbitrary predicate I don't have control over.
import polars as pl
df = pl.DataFrame(dict(x=[0, 0, 1, 1], y=[1, 2, 3, 3]))
expression = pl.col("y") == pl.col("y").max().over("x")
df.filter(expression)
# shape: (3, 2)
# ┌─────┬─────┐
# │ x   ┆ y   │
# │ --- ┆ --- │
# │ i64 ┆ i64 │
# ╞═════╪═════╡
# │ 0   ┆ 2   │
# ├╌╌╌╌╌┼╌╌╌╌╌┤
# │ 1   ┆ 3   │
# ├╌╌╌╌╌┼╌╌╌╌╌┤
# │ 1   ┆ 3   │
# └─────┴─────┘
# └─────┴─────┘