We'll adapt the algorithm from this Stack Overflow question. I say "adapt" because, from the minimal example you provided, it appears that the lists in col2 and col3 are not sets, in that some values are duplicated within a list. As such, we'll need to first remove the duplicates so that we are working with sets (as opposed to lists).
The algorithm does not require that the values be numbers. To demonstrate, let's convert col2 and col3 to strings.
import polars as pl
df = pl.DataFrame(
{
"col1": ["abc", "def", "ghi"],
"col2": [["cat", "dog", "mouse", "bird"], ["cat", "dog"], ["snail", "worm"]],
"col3": [["cat", "bird"], ["dog", "dog"], ["mouse", "bird", "starfish"]],
}
)
df
shape: (3, 3)
┌──────┬────────────────────────────┬───────────────────────────────┐
│ col1 ┆ col2 ┆ col3 │
│ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] │
╞══════╪════════════════════════════╪═══════════════════════════════╡
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ def ┆ ["cat", "dog"] ┆ ["dog", "dog"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] │
└──────┴────────────────────────────┴───────────────────────────────┘
Adapting the algorithm from the Stack Overflow question mentioned above, we can calculate the intersection for your minimal example as follows:
df.with_column(
    pl.col("col2")
    .arr.unique()                             # de-dup col2
    .arr.concat(pl.col('col3').arr.unique())  # de-dup col3 and append it to col2
    .arr.eval(pl.element().filter(pl.element().is_duplicated()), parallel=True)  # keep values present in both
    .arr.unique()                             # de-dup the result
    .alias('intersection')
)
shape: (3, 4)
┌──────┬────────────────────────────┬───────────────────────────────┬─────────────────┐
│ col1 ┆ col2 ┆ col3 ┆ intersection │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] ┆ list[str] │
╞══════╪════════════════════════════╪═══════════════════════════════╪═════════════════╡
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ ["cat", "bird"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ def ┆ ["cat", "dog"] ┆ ["dog", "dog"] ┆ ["dog"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ [] │
└──────┴────────────────────────────┴───────────────────────────────┴─────────────────┘
For a discussion of how the algorithm works, please see the above-mentioned Stack Overflow question.
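To make the idea concrete, here's a minimal sketch in plain Python (not part of the original answer) of the same trick applied to a single pair of lists: de-dup each list, concatenate them, and keep only the values that now appear more than once. The names a, b, and intersection are illustrative only.
# De-dup each list, concatenate, and keep values that appear more than once:
# those are exactly the values common to both lists.
a = ["cat", "dog", "mouse", "bird"]
b = ["cat", "bird"]
combined = list(set(a)) + list(set(b))
intersection = {x for x in combined if combined.count(x) > 1}
print(intersection)  # {'cat', 'bird'}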
Performance
To get some idea of performance, let's take our data and replicate it to roughly 100 million records. We'll use a cross join to accomplish this easily.
df = pl.DataFrame(
{
"col1": ["abc", "def", "ghi"],
"col2": [["cat", "dog", "mouse", "bird"], ["cat", "dog"], ["snail", "worm"]],
"col3": [["cat", "bird"], ["dog", "dog"], ["mouse", "bird", "starfish"]],
}
)
nbr_groups = 34_000_000
df = (
df
.join(
pl.DataFrame({
'group': pl.arange(0, nbr_groups, eager=True)
}),
how="cross"
)
)
df
shape: (102000000, 4)
┌──────┬────────────────────────────┬───────────────────────────────┬──────────┐
│ col1 ┆ col2 ┆ col3 ┆ group │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] ┆ i64 │
╞══════╪════════════════════════════╪═══════════════════════════════╪══════════╡
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 0 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 1 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 2 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 3 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ ... ┆ ... ┆ ... ┆ ... │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999996 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999997 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999998 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999999 │
└──────┴────────────────────────────┴───────────────────────────────┴──────────┘
On my platform (a 32-core Threadripper Pro), I timed the algorithm using Python's time.perf_counter. The algorithm took 163 seconds to complete for the 102M records.
import time
start = time.perf_counter()
df.with_column(
pl.col("col2")
.arr.unique()
.arr.concat(pl.col('col3').arr.unique())
.arr.eval(pl.element().filter(pl.element().is_duplicated()), parallel=True)
.arr.unique()
.alias('intersection')
)
print(time.perf_counter() - start)
shape: (102000000, 5)
┌──────┬────────────────────────────┬───────────────────────────────┬──────────┬─────────────────┐
│ col1 ┆ col2 ┆ col3 ┆ group ┆ intersection │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] ┆ i64 ┆ list[str] │
╞══════╪════════════════════════════╪═══════════════════════════════╪══════════╪═════════════════╡
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 0 ┆ ["bird", "cat"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 1 ┆ ["bird", "cat"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 2 ┆ ["bird", "cat"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ abc ┆ ["cat", "dog", ... "bird"] ┆ ["cat", "bird"] ┆ 3 ┆ ["cat", "bird"] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ... ┆ ... ┆ ... ┆ ... ┆ ... │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999996 ┆ [] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999997 ┆ [] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999998 ┆ [] │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ ghi ┆ ["snail", "worm"] ┆ ["mouse", "bird", "starfish"] ┆ 33999999 ┆ [] │
└──────┴────────────────────────────┴───────────────────────────────┴──────────┴─────────────────┘
print(time.perf_counter() - start)
162.8689589149999
As you watch CPU usage, you'll note that the algorithm initially spawns only two threads. These two threads are presumably de-duplicating col2 and col3 (the arr.unique calls).
After both threads complete, you'll see the algorithm spread across all CPU cores as it calculates the intersection in the arr.eval step. (Note: the parallel=True keyword is important; see the comparison sketch below.)
The algorithm then returns to a single thread to de-dup the intersection column (a necessary step).
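If you want to see the effect of that keyword on your own platform, one option (not something benchmarked in the original answer) is to time the same expression with parallel=False and compare:
# Same expression as above, but with parallel=False in the arr.eval step.
# Timing this against the parallel=True version shows how much of the
# runtime comes from the intersection calculation itself.
df.with_column(
    pl.col("col2")
    .arr.unique()
    .arr.concat(pl.col('col3').arr.unique())
    .arr.eval(pl.element().filter(pl.element().is_duplicated()), parallel=False)
    .arr.unique()
    .alias('intersection')
)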
If your actual data does not contain duplicate values in the lists (in the equivalent of col2 and col3), then you can skip the initial de-dup step, as in the sketch below.
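Here's a sketch of what that would look like, assuming col2 and col3 already contain no duplicates within each list. Note that the final arr.unique is still needed, because each intersecting value appears twice after the concat.
# Assumes col2 and col3 are already sets (no duplicates within a list),
# so the initial arr.unique calls are dropped.
df.with_column(
    pl.col("col2")
    .arr.concat(pl.col('col3'))
    .arr.eval(pl.element().filter(pl.element().is_duplicated()), parallel=True)
    .arr.unique()  # each common value appears twice after the concat; collapse to one
    .alias('intersection')
)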
One further note: you can try the algorithm on the original strings rather than the integer values that you created. For your actual data and your computing platform, you'll have to judge whether any speed-up from running the algorithm on numbers is worth the cost of converting the strings to numbers and then converting them back to strings afterwards.
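If you do want to test the numeric route, a minimal sketch of one way to do the round trip might look like the following (the codes dictionary and the df_int name are illustrative, not from your data):
# Hypothetical encoding step: map each distinct string to an integer code
# before building the DataFrame, and keep the reverse mapping to decode
# the intersection afterwards.
codes = {"cat": 0, "dog": 1, "mouse": 2, "bird": 3, "snail": 4, "worm": 5, "starfish": 6}
decode = {v: k for k, v in codes.items()}

col2 = [["cat", "dog", "mouse", "bird"], ["cat", "dog"], ["snail", "worm"]]
col3 = [["cat", "bird"], ["dog", "dog"], ["mouse", "bird", "starfish"]]

df_int = pl.DataFrame(
    {
        "col1": ["abc", "def", "ghi"],
        "col2": [[codes[s] for s in row] for row in col2],
        "col3": [[codes[s] for s in row] for row in col3],
    }
)
# ...run the intersection algorithm on df_int, then map the integer codes
# in the intersection column back to strings using the decode dictionary.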