I am trying to find the common elements between each list in a list column and a reference cell. I could accomplish it with a small dataset, but I face two problems. First, it is excruciatingly slow, even for 25 rows of sample data (20.7 s ± 52 ms per loop). Second, I am unable to find a faster implementation through map, which can use parallelization, unlike apply, which works on a single thread.
The version that I have working right now is as follows:
>>> import polars as pl
>>> import numpy as np
>>>
>>> df = pl.DataFrame({'animal': ['goat','tiger','goat','tiger','lion','goat','tiger','lion'], 'food': ['grass','rabbit','carrots','deer','zebra','water','water','water']})
>>> dl = df.groupby('animal').agg_list()
>>> dl
shape: (3, 2)
┌────────┬───────────────────────────────┐
│ animal ┆ food │
│ --- ┆ --- │
│ str ┆ list[str] │
╞════════╪═══════════════════════════════╡
│ lion ┆ ["zebra", "water"] │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ tiger ┆ ["rabbit", "deer", "water"] │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ goat ┆ ["grass", "carrots", "water"] │
└────────┴───────────────────────────────┘
>>>
>>> refn = dl['food'][1].to_numpy()
>>> dl['food'] = dl['food'].apply(lambda x: np.intersect1d(refn,x.to_numpy()))
>>> dl
shape: (3, 2)
┌────────┬───────────────────────────┐
│ animal ┆ food │
│ --- ┆ --- │
│ str ┆ object │
╞════════╪═══════════════════════════╡
│ lion ┆ ['water'] │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ tiger ┆ ['deer' 'rabbit' 'water'] │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ goat ┆ ['water'] │
└────────┴───────────────────────────┘
>>>
Any help will be greatly appreciated. TIA.