
I'm trying to port a pandas script to polars. I have a dataset that looks like this:

sid,roi,endpoint,value,std,voxel_count
4213-a3_bl,AF_L,afd_along,0.40,0.21,57334
4213-a3_bl,AF_L,radfODF,0.08,0.045,57334
4213-a3_bl,AF_R,afd_along,0.42,0.22,53916
4213-a3_bl,AF_R,radfODF,0.08,0.04,53916
4213-a3_bl,CC_1,afd_along,,,
4213-a3_bl,CC_1,radfODF,,,
4213-a3_bl,CC_2a,afd_along,0.54,0.30,3264
4225-a3_bl,CC_2a,radfODF,0.06,0.04,3264
4225-a3_bl,CC_2b,afd_along,0.47,0.24,18833
... thousands of rows ...

I want to add a column based on a groupby:

df.filter(pl.col('roi') == 'wm_mask').groupby('sid').first()

                 roi endpoint     value       std  voxel_count
sid                                                              
4213-a3_bl   wm_mask       ad  0.001074  0.000237       602620
4225-a3_bl   wm_mask       ad  0.001071  0.000242       718758
4229-a3_bl   wm_mask       ad  0.001045  0.000243       579756
4473-a3_bl   wm_mask       ad  0.001059  0.000259       662894
4654-a3_bl   wm_mask       ad  0.001083  0.000234       562841
...              ...      ...       ...       ...          ...

Now I want to add these voxel_count values as a new column, matched to the right sid on every row, which should give something like:

sid,roi,endpoint,value,std,voxel_count,     wm_mask__count
4213-a3_bl,AF_L,afd_along,0.40,0.21,57334,  602620
4213-a3_bl,AF_L,radfODF,0.08,0.045,57334,   602620
4213-a3_bl,AF_R,afd_along,0.42,0.22,53916,  602620
4213-a3_bl,AF_R,radfODF,0.08,0.04,53916,    602620
4213-a3_bl,CC_1,afd_along,,,,               602620
4213-a3_bl,CC_1,radfODF,,,,                 602620
4213-a3_bl,CC_2a,afd_along,0.54,0.30,3264,  602620
4225-a3_bl,CC_2a,radfODF,0.06,0.04,3264,    718758
4225-a3_bl,CC_2b,afd_along,0.47,0.24,18833, 718758
... thousands of rows ...

I tried various things, but I always end up with `AttributeError: _s`. How can I express this in polars?

If it helps, the equivalent pandas code is:

df = df.set_index("sid", drop=True)
df_wm_volumes = df[df.roi == "wm_mask"].groupby("sid", as_index=True).first()
df["wm_mask__volume"] = df_wm_volumes["voxel_count"]
df = df.reset_index(drop=False)
Nil

2 Answers


We can most easily accomplish this with a left join in Polars.

First, I'll add the following two lines to your input, so that we have some rows where `roi == 'wm_mask'`:

4213-a3_bl,wm_mask,,,,602620
4225-a3_bl,wm_mask,,,,718758

Our data now looks like this:

shape: (11, 6)
┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┐
│ sid        ┆ roi     ┆ endpoint  ┆ value ┆ std   ┆ voxel_count │
│ ---        ┆ ---     ┆ ---       ┆ ---   ┆ ---   ┆ ---         │
│ str        ┆ str     ┆ str       ┆ f64   ┆ f64   ┆ i64         │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╡
│ 4213-a3_bl ┆ AF_L    ┆ afd_along ┆ 0.4   ┆ 0.21  ┆ 57334       │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ AF_L    ┆ radfODF   ┆ 0.08  ┆ 0.045 ┆ 57334       │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ AF_R    ┆ afd_along ┆ 0.42  ┆ 0.22  ┆ 53916       │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ AF_R    ┆ radfODF   ┆ 0.08  ┆ 0.04  ┆ 53916       │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ CC_1    ┆ afd_along ┆ null  ┆ null  ┆ null        │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ CC_1    ┆ radfODF   ┆ null  ┆ null  ┆ null        │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ CC_2a   ┆ afd_along ┆ 0.54  ┆ 0.3   ┆ 3264        │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 602620      │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4225-a3_bl ┆ CC_2a   ┆ radfODF   ┆ 0.06  ┆ 0.04  ┆ 3264        │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4225-a3_bl ┆ CC_2b   ┆ afd_along ┆ 0.47  ┆ 0.24  ┆ 18833       │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4225-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 718758      │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┘

Next, we'll run the groupby to obtain our wm_mask__count values. I've changed your groupby to something more idiomatic in Polars: instead of calling `first()` on every column, we select only `voxel_count` inside `agg` and name the result directly.

mask_counts = (
    df
    .filter(pl.col('roi') == 'wm_mask')
    .group_by('sid')  # spelled .groupby() before Polars 0.19
    .agg([
        pl.col('voxel_count').first().alias('wm_mask__count')
    ])
)
mask_counts
shape: (2, 2)
┌────────────┬────────────────┐
│ sid        ┆ wm_mask__count │
│ ---        ┆ ---            │
│ str        ┆ i64            │
╞════════════╪════════════════╡
│ 4225-a3_bl ┆ 718758         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ 602620         │
└────────────┴────────────────┘

And then we'll use a "left" join to merge the result back into the original data:

df.join(
    mask_counts,
    on=['sid'],
    how='left',
)
shape: (11, 7)
┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┬────────────────┐
│ sid        ┆ roi     ┆ endpoint  ┆ value ┆ std   ┆ voxel_count ┆ wm_mask__count │
│ ---        ┆ ---     ┆ ---       ┆ ---   ┆ ---   ┆ ---         ┆ ---            │
│ str        ┆ str     ┆ str       ┆ f64   ┆ f64   ┆ i64         ┆ i64            │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╪════════════════╡
│ 4213-a3_bl ┆ AF_L    ┆ afd_along ┆ 0.4   ┆ 0.21  ┆ 57334       ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ AF_L    ┆ radfODF   ┆ 0.08  ┆ 0.045 ┆ 57334       ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ AF_R    ┆ afd_along ┆ 0.42  ┆ 0.22  ┆ 53916       ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ AF_R    ┆ radfODF   ┆ 0.08  ┆ 0.04  ┆ 53916       ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ CC_1    ┆ afd_along ┆ null  ┆ null  ┆ null        ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ CC_1    ┆ radfODF   ┆ null  ┆ null  ┆ null        ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ CC_2a   ┆ afd_along ┆ 0.54  ┆ 0.3   ┆ 3264        ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4213-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 602620      ┆ 602620         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4225-a3_bl ┆ CC_2a   ┆ radfODF   ┆ 0.06  ┆ 0.04  ┆ 3264        ┆ 718758         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4225-a3_bl ┆ CC_2b   ┆ afd_along ┆ 0.47  ┆ 0.24  ┆ 18833       ┆ 718758         │
├╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ 4225-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 718758      ┆ 718758         │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┴────────────────┘

  • In Pandas, I believe merging datasets involves setting an index. (I'm not sure -- I've rarely used Pandas.) However, in Polars, there is no concept of an index. Instead, we use joins, where we simply name our join columns in the `on` parameter of `join`. –  Jun 09 '22 at 15:17
  • Wow, thank you! IIUC, polars "way of doing things" is more akin to SQL. This is excellent. – Nil Jun 09 '22 at 17:21
  • Pandas has (too) many features, both `join` and `merge`, and both(!) functions can use either index or column values to merge dataframes. – creanion Jun 13 '22 at 09:47
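To illustrate the last comment, here is a minimal pandas sketch on toy data (not the full dataset) that needs no index at all: it aggregates on a regular `sid` column and merges on it, which has the same shape as the Polars solution. The column name `wm_mask__volume` is taken from the question's pandas code:

```python
import pandas as pd

# Toy subset of the question's data
df = pd.DataFrame({
    "sid": ["4213-a3_bl", "4213-a3_bl", "4213-a3_bl", "4225-a3_bl", "4225-a3_bl"],
    "roi": ["AF_L", "AF_R", "wm_mask", "CC_2b", "wm_mask"],
    "voxel_count": [57334, 53916, 602620, 18833, 718758],
})

# as_index=False keeps sid as an ordinary column instead of the index
wm = (
    df[df["roi"] == "wm_mask"]
    .groupby("sid", as_index=False)["voxel_count"]
    .first()
    .rename(columns={"voxel_count": "wm_mask__volume"})
)

# Column-based left join, no set_index/reset_index needed
out = df.merge(wm, on="sid", how="left")
print(out)
```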

You can do this with a single polars window expression by combining `polars.Expr.filter` and `polars.Expr.over`:

import polars as pl
from io import StringIO 

csv = StringIO(
"""sid,roi,endpoint,value,std,voxel_count
4213-a3_bl,AF_L,afd_along,0.40,0.21,57334
4213-a3_bl,AF_L,radfODF,0.08,0.045,57334
4213-a3_bl,AF_R,afd_along,0.42,0.22,53916
4213-a3_bl,AF_R,radfODF,0.08,0.04,53916
4213-a3_bl,CC_1,afd_along,,,
4213-a3_bl,CC_1,radfODF,,,
4213-a3_bl,CC_2a,afd_along,0.54,0.30,3264
4213-a3_bl,wm_mask,,,,602620\n""" # <--- added
"""4225-a3_bl,CC_2a,radfODF,0.06,0.04,3264
4225-a3_bl,CC_2b,afd_along,0.47,0.24,18833
4225-a3_bl,wm_mask,,,,718758""" # <--- added
)

df = pl.read_csv(csv)

res = df.with_columns(
    pl.col('voxel_count')
    .filter(pl.col('roi') == 'wm_mask')
    .first().over('sid')
    .alias('wm_mask__count')
)

Output of `res`:

sid         roi      endpoint   value  std    voxel_count  wm_mask__count
4213-a3_bl  AF_L     afd_along  0.4    0.21   57334        602620
4213-a3_bl  AF_L     radfODF    0.08   0.045  57334        602620
4213-a3_bl  AF_R     afd_along  0.42   0.22   53916        602620
4213-a3_bl  AF_R     radfODF    0.08   0.04   53916        602620
4213-a3_bl  CC_1     afd_along  null   null   null         602620
4213-a3_bl  CC_1     radfODF    null   null   null         602620
4213-a3_bl  CC_2a    afd_along  0.54   0.3    3264         602620
4213-a3_bl  wm_mask  null       null   null   602620       602620
4225-a3_bl  CC_2a    radfODF    0.06   0.04   3264         718758
4225-a3_bl  CC_2b    afd_along  0.47   0.24   18833        718758
4225-a3_bl  wm_mask  null       null   null   718758       718758
Rodalm