I found that, at least for the scenario below, using over is much slower (2-3x) than doing groupby/agg followed by explode, even though the two approaches produce exactly the same results.

Based on this finding, I have the following questions:

  • Is this behaviour expected? If so, should we always use the two-step procedure (groupby/agg + explode) instead of using over directly?
  • Or does this mean there is room to optimize over?
  • Or does the relative performance of the two approaches depend on the problem setup, so that users should try both and see which one fits better?
import time

import numpy as np
import polars as pl
from polars.testing import assert_frame_equal

## setup
rng = np.random.default_rng(1)

nrows = 20_000_000
df = pl.DataFrame(
    dict(
        id=rng.integers(1, 50, nrows),
        id2=rng.integers(1, 500, nrows),
        v=rng.normal(0, 1, nrows),
        v1=rng.normal(0, 1, nrows),
        v2=rng.normal(0, 1, nrows),
        v3=rng.normal(0, 1, nrows),
        v4=rng.normal(0, 1, nrows),
        v5=rng.normal(0, 1, nrows),
        v6=rng.normal(0, 1, nrows),
        v7=rng.normal(0, 1, nrows),
        v8=rng.normal(0, 1, nrows),
        v9=rng.normal(0, 1, nrows),
        v10=rng.normal(0, 1, nrows),
    )
)

## over
start = time.perf_counter()
res = (
    df.lazy()
    .select(
        [
            "id",
            "id2",
            *[
                (pl.col(f"v{i}") - pl.col(f"v{i}").mean().over(["id", "id2"]))
                / pl.col(f"v{i}").std().over(["id", "id2"])
                for i in range(1, 11)
            ],
        ]
    )
    .collect()
)
time.perf_counter() - start
# 8.541702497983351

## groupby/agg + explode
start = time.perf_counter()
res2 = (
    df.lazy()
    .groupby(["id", "id2"])
    .agg(
        [
            (pl.col(f"v{i}") - pl.col(f"v{i}").mean()) / pl.col(f"v{i}").std()
            for i in range(1, 11)
        ],
    )
    .explode(pl.exclude(["id", "id2"]))
    .collect()
)
time.perf_counter() - start
# 3.1841439900454134

## compare results
assert_frame_equal(res.sort(["id", "id2"]), res2.sort(["id", "id2"])[res.columns])
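
The timings above come from a single run of each approach, so some of the gap could be noise. A minimal sketch of how the measurement could be repeated is shown below; `time_call` is a hypothetical helper name (not part of polars), and it only assumes that each approach can be wrapped in a zero-argument callable.

```python
import time
from statistics import median


def time_call(fn, repeats=5):
    """Run fn() `repeats` times; return (best, median) wall-clock seconds.

    Hypothetical helper for illustration only: `fn` is any zero-argument
    callable, e.g. a lambda that builds and collects one of the queries.
    """
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return min(timings), median(timings)


# Stand-in workload so the sketch runs on its own; replace the lambda
# with the over query or the groupby/agg + explode query to compare them.
best, med = time_call(lambda: sum(range(100_000)), repeats=3)
print(f"best={best:.6f}s median={med:.6f}s")
```

Reporting both the best and the median run should make it easier to tell whether the 2-3x gap is stable across runs.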