
Is there a better way to handle set-like operations with Polars?

The problem: I have data with connected IDs, where I want to group all connected IDs together and attach a unique ID to every group; in this simple case, just an integer.

This is an example of the data I'm looking at:

import polars as pl

data = {"ConnectedIDs": [[1, 2], [6], [8], [2], [7], [2, 7], [7, 9], [6, 8]]}
df = pl.DataFrame(data)

I have a dataframe with connected IDs, which are lists. I would like to find and assign a unified ID to all rows. In this case, if one row has [1, 2] and another has [2, 7], both rows should get the same ID, because they are interconnected. The same happens with the row [7, 9], because it shares an ID with [2, 7]. But the rows [6], [8] and [6, 8] get a different ID.

For this dataframe, I expect to get something like this in the end (order does not matter):

data = {"ConnectedID": [1, 2, 6, 8, 7, 9], "UnitedID": [0, 0, 1, 1, 0, 0]}
df = pl.DataFrame(data)
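Just to make the target unambiguous, the expected grouping can be double-checked with a small plain-Python sketch (the helper name `connected_components` is my own; it does a traversal over the shared-ID graph):

```python
from collections import defaultdict


def connected_components(rows):
    # Build an adjacency map: every ID in a row is linked to
    # every other ID that appears in the same row.
    adjacency = defaultdict(set)
    for ids in rows:
        for i in ids:
            adjacency[i].update(ids)

    # Depth-first traversal over the adjacency map to collect components.
    seen, components = set(), []
    for start in adjacency:
        if start in seen:
            continue
        component, stack = set(), [start]
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            component.add(node)
            stack.extend(adjacency[node] - seen)
        components.append(component)
    return components


rows = [[1, 2], [6], [8], [2], [7], [2, 7], [7, 9], [6, 8]]
print(connected_components(rows))  # the two components {1, 2, 7, 9} and {6, 8}
```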

Here is a simplified version of the code, which does exactly this using Polars:

from polars import col


def compute_united_ids_pure_polars(connected_df, previous=None):
    # Explode the lists, group by each individual ID, and re-aggregate the
    # merged neighbour lists; repeat until the number of groups stops shrinking.
    computed = (
        connected_df.with_columns(col("ConnectedIDs").alias("ConnectedIDs2"))
        .explode("ConnectedIDs")
        .groupby("ConnectedIDs")
        .agg([col("ConnectedIDs2").flatten().unique()])
        .select(col("ConnectedIDs2").alias("ConnectedIDs"))
        .unique()
    )

    count = computed.select(pl.count())["count"][0]

    if previous is None or count < previous:
        # Still converging: recurse with the reduced frame.
        return compute_united_ids_pure_polars(computed, count)
    else:
        # Converged: number each group and explode back to one row per ID.
        return computed.with_columns(
            pl.arange(0, pl.count()).alias("UnitedID")
        ).explode("ConnectedIDs")

I could make it work for this simple set, but it does not finish for even a small real dataset; it is probably too expensive to explode the lists so many times, but I do not see any better approach in the Polars API.

However, I have the same algorithm implemented with plain Python functions, and it is pretty fast:

def compute_united_ids_tree_sets(connected_df):
    # Maps each ID to its parent ID; roots map to themselves.
    parent_ids_dict = dict()

    for row in connected_df.iter_rows(named=True):
        connected_ids = row["ConnectedIDs"]
        opt_add_row_ids_to_sets(parent_ids_dict, connected_ids)

    connected_ids_dict = build_connected_ids_dict(parent_ids_dict)

    return pl.DataFrame(
        {
            "ConnectedID": list(connected_ids_dict.keys()),
            "UnitedID": list(connected_ids_dict.values()),
        }
    )


def opt_add_row_ids_to_sets(parent_ids_dict, connected_ids):
    parent_ids = []
    parent_set = set()

    for connected_id in connected_ids:
        (_, parent_id) = parent_tuple = get_parent(parent_ids_dict, connected_id)
        parent_ids.append(parent_tuple)
        parent_set.add(parent_id)

    parents_len = len(parent_set)
    if parents_len == 1:
        # Nothing to merge
        pass
    else:
        # Merge tree sets
        _count, new_parent_id = max(parent_ids, key=lambda x: x[0])
        parent_set.remove(new_parent_id)
        for parent_id in parent_set:
            parent_ids_dict[parent_id] = new_parent_id


def get_parent(parent_ids_dict, connected_id, count=0):
    parent_id = parent_ids_dict.get(connected_id)
    if parent_id is None:
        # First time we see this ID: it becomes its own root.
        parent_ids_dict[connected_id] = connected_id
        return (count, connected_id)
    elif parent_id == connected_id:
        # Found the root.
        return (count + 1, connected_id)
    else:
        # Walk up the chain towards the root.
        return get_parent(parent_ids_dict, parent_id, count + 1)


def build_connected_ids_dict(parent_ids_dict):
    connected_ids_dict = {}
    for connected_id in parent_ids_dict.keys():
        _count, parent_id = get_parent(parent_ids_dict, connected_id)
        connected_ids_dict[connected_id] = parent_id

    return connected_ids_dict

compute_united_ids_tree_sets(df)

It of course uses iter_rows with a custom Python function. This is the optimized version, but basically what it does is fetch the parent of each of the row's IDs, and if the IDs in the same row have multiple distinct parents, it re-points all of those parents to a single shared parent ID.
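In other words, this is a union-find (disjoint-set) structure. For comparison, here is a compact textbook sketch with path compression (the names `find`/`union` are mine, not the ones used above):

```python
def find(parent, x):
    # Walk up to the root, then compress the path behind us.
    root = x
    while parent.setdefault(root, root) != root:
        root = parent[root]
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root


def union(parent, ids):
    # Merge all IDs in one row under a single root.
    roots = {find(parent, i) for i in ids}
    target = min(roots)
    for r in roots:
        parent[r] = target


parent = {}
for row in [[1, 2], [6], [8], [2], [7], [2, 7], [7, 9], [6, 8]]:
    union(parent, row)
groups = {i: find(parent, i) for i in parent}
print(groups)  # every ID maps to its component's root
```

The optimized version above uses the chain depth (the `count` returned by `get_parent`) to pick which parent survives a merge; the sketch just picks the minimum, which is simpler but has the same effect on correctness.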

My question: is there a better way to solve something this complex with Polars? Or is this a perfect case for custom Python functions, where for further optimisation and concurrency the only way forward is Rust?

P.S.:

Performance-wise, I measured only how long each `computed` step took:

compute_united_ids_pure_polars(df_real)
Elapsed time: 0.1691 seconds
Elapsed time: 1.1517 seconds

And afterwards it never finished.

A full run of the Python iterative version:

compute_united_ids_tree_sets(df_real)
Elapsed time: 0.1157 seconds

It seems that exploding lists this many times is the wrong approach here.

  • Can you show the code for `def compute_united_ids_pure_polars_inner`? – jqurious Jun 19 '23 at 16:25
  • @jqurious Updated, it was recursive call, I just forgot to update both names. compute_united_ids_pure_polars_inner = compute_united_ids_pure_polars – Dmitry Russ Jul 11 '23 at 09:19
