Is there a better way to handle set-like operations with Polars?
The problem: I have data with connected ids, and I want to group all connected ids together and attach some unique id to every group; in this simple case, just a positive integer.
This is an example of the data I'm working with:
import polars as pl
from polars import col

data = {"ConnectedIDs": [[1, 2], [6], [8], [2], [7], [2, 7], [7, 9], [6, 8]]}
df = pl.DataFrame(data)
I have a dataframe with connected ids, which are lists. I would like to find and assign a unified id to all rows. In this case, if one row has [1, 2] and another has [2, 7], both rows should get the same id, because they are interconnected. The same applies to the row [7, 9], because it shares an id with [2, 7]. But the rows [6], [8] and [6, 8] get a different id.
For this dataframe, I expect to get something like this in the end (order does not matter):
data = {"ConnectedID": [1, 2, 6, 8, 7, 9], "UnitedID": [0, 0, 1, 1, 0, 0]}
df = pl.DataFrame(data)
Here is a simplified version of the code that does exactly this using Polars:
def compute_united_ids_pure_polars(connected_df, previous=None):
    computed = (
        # For every individual id, collect and merge all the lists it appears in
        connected_df.with_columns(col("ConnectedIDs").alias("ConnectedIDs2"))
        .explode("ConnectedIDs")
        .groupby("ConnectedIDs")
        .agg([col("ConnectedIDs2").flatten().unique()])
        .select(col("ConnectedIDs2").alias("ConnectedIDs"))
        .unique()
    )
    count = computed.select(pl.count())["count"][0]
    if previous is None or count < previous:
        # The number of groups is still shrinking, keep merging
        return compute_united_ids_pure_polars(computed, count)
    else:
        # Number the groups and go back to one row per id
        return computed.with_columns(
            pl.arange(0, pl.count()).alias("UnitedID")
        ).explode("ConnectedIDs")
I could make it work for this simple set, but it does not finish for a small real set; it is probably too expensive to explode the lists so many times, but I do not see any better approach in the Polars API.
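As a rough illustration of the cost I suspect (reusing the hypothetical df_example from the sketch above): even a single explode pass produces one row per (row, id) pair, and the lists only get longer as groups are merged on each recursion:

# A single explode pass on the toy data: 12 rows out of 8 input rows
exploded = df_example.explode("ConnectedIDs")
print(exploded.height)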
But I also have the same algorithm implemented with plain Python functions, and it is pretty fast:
def compute_united_ids_tree_sets(connected_df):
    parent_ids_dict = dict()
    for idx, row in enumerate(connected_df.iter_rows(named=True)):
        connected_ids = row["ConnectedIDs"]
        opt_add_row_ids_to_sets(parent_ids_dict, connected_ids)
    connected_ids_dict = build_connected_ids_dict(parent_ids_dict)
    return pl.DataFrame(
        {
            "ConnectedID": connected_ids_dict.keys(),
            "UnitedID": connected_ids_dict.values(),
        }
    )


def opt_add_row_ids_to_sets(parent_ids_dict, connected_ids):
    parent_ids = []
    parent_set = set()
    for connected_id in connected_ids:
        (_, parent_id) = parent_tuple = get_parent(parent_ids_dict, connected_id)
        parent_ids.append(parent_tuple)
        parent_set.add(parent_id)
    parents_len = len(parent_set)
    if parents_len == 1:
        # Nothing to merge
        pass
    else:
        # Merge tree sets: re-parent all other roots onto the deepest one
        _count, new_parent_id = max(parent_ids, key=lambda x: x[0])
        parent_set.remove(new_parent_id)
        for parent_id in parent_set:
            parent_ids_dict[parent_id] = new_parent_id


def get_parent(parent_ids_dict, connected_id, count=0):
    # Walk up the tree until the root (an id that is its own parent) is found
    parent_id = parent_ids_dict.get(connected_id)
    if parent_id is None:
        parent_ids_dict[connected_id] = connected_id
        return (count, connected_id)
    elif parent_id == connected_id:
        return (count + 1, connected_id)
    else:
        return get_parent(parent_ids_dict, parent_id, count + 1)


def build_connected_ids_dict(parent_ids_dict):
    # Resolve every id to its root, which becomes the UnitedID
    connected_ids_dict = {}
    for connected_id in parent_ids_dict.keys():
        _count, parent_id = get_parent(parent_ids_dict, connected_id)
        connected_ids_dict[connected_id] = parent_id
    return connected_ids_dict
compute_united_ids_tree_sets(df)
It of course uses iter_rows with a custom Python function to compute this.
This is an optimized version, but basically what it does is fetch the parent of every id in a row, and if the ids in the same row turn out to have multiple parents, it merges them so that they all end up under the same parent.
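To make the merging concrete, here is a small sketch that runs the helpers above over the example rows; the exact root ids depend on the row order, but ids 1, 2, 7 and 9 end up under one root and 6 and 8 under another:

rows = [[1, 2], [6], [8], [2], [7], [2, 7], [7, 9], [6, 8]]
parent_ids_dict = {}
for ids in rows:
    opt_add_row_ids_to_sets(parent_ids_dict, ids)
# Every id now resolves to its group's root,
# e.g. {1: 1, 2: 1, 6: 6, 8: 6, 7: 1, 9: 1}
print(build_connected_ids_dict(parent_ids_dict))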
My question is: is there a better way to solve something complex like this with Polars? Or is this a perfect case for Python functions, where the only way to get more optimisation and concurrency is to go down to Rust?
P.S.:
Performance-wise, for the pure Polars version I measured only how long computed took to execute on each pass:
compute_united_ids_pure_polars(df_real)
Elapsed time: 0.1691 seconds
Elapsed time: 1.1517 seconds
And afterwards it never finished.
The full run of the Python iterative version:
compute_united_ids_tree_sets(df_real)
Elapsed time: 0.1157 seconds
It seems that exploding lists is the wrong approach for this.