2

How do I find the set intersection of a column of lists?

[dependencies]
polars = { version = "*", features = ["lazy"] }
use polars::df;
use polars::prelude::*;

fn main() {
    // Grouping key and the values to be gathered per group.
    let bars = ["a", "b", "c", "a", "b", "c", "a", "c"];
    let hams = ["foo", "foo", "foo", "bar", "bar", "bar", "bing", "bang"];
    let df = df!["bar" => bars, "ham" => hams].unwrap();

    // Lazily group by `bar`, folding every `ham` of a group into one list[str] cell,
    // then materialize the plan.
    let plan = df.lazy().groupby(["bar"]).agg([col("ham").list()]);
    let df_grp = plan.collect().unwrap();

    println!("{:?}", df_grp);
}

prints:

┌─────┬────────────────────────┐
│ bar ┆ ham                    │
│ --- ┆ ---                    │
│ str ┆ list[str]              │
╞═════╪════════════════════════╡
│ c   ┆ ["foo", "bar", "bang"] │
├╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ b   ┆ ["foo", "bar"]         │
├╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ a   ┆ ["foo", "bar", "bing"] │
└─────┴────────────────────────┘

What I would like to do is compute the set intersection of rows a/b/c, which should give ["foo", "bar"]: the strings common to all rows.

My thought was to turn the column of lists of strings into a column of hash sets, and then fold/reduce with intersection. How do I go from Series<list<String>> to Series<HashSet>? If this is possible in a lazy-frame fold expression, that would be great, but how would I define the accumulator? lit(HashSet)?

E_net4
  • 27,810
  • 13
  • 101
  • 139
jharting
  • 71
  • 3

1 Answer

0

I found a way to do this, though not with expressions.

use polars::df;
use polars::prelude::*;
use std::collections::HashSet;

fn main() -> Result<(), PolarsError> {
    let df = df![
        "bar" => ["a", "b", "c",
                  "a", "b", "c",
                  "a", "b", "c"],
        "ham" => ["foo", "foo", "foo",
                  "bar", "bar", "bar",
                  "bing", "bang", "bing"]
    ]?;

    // Gather each group's `ham` values into a list column, sorted by key for stable output.
    let df_grp = df
        .lazy()
        .groupby(["bar"])
        .agg([col("ham").list()])
        .sort("bar", Default::default())
        .collect()?;

    println!("{:?}", df_grp);

    // One Vec<String> per group, in the sorted group order.
    let mut s_sets: Vec<Vec<String>> = Vec::new();
    for opt_lst in df_grp.column("ham")?.list()?.into_iter() {
        let strings: Vec<String> = match opt_lst {
            // A null list has no members. Use an empty set rather than a sentinel
            // value that could collide with a real empty string.
            None => Vec::new(),
            // Skip null strings explicitly instead of using `into_no_null_iter`,
            // which assumes the array contains no nulls.
            Some(lst) => lst
                .utf8()?
                .into_iter()
                .flatten()
                .map(str::to_string)
                .collect(),
        };
        s_sets.push(strings);
    }

    let common = find_common_ids(s_sets);
    println!("{:?}", common);

    Ok(())
}

/// Returns the strings that appear in every inner list of `callset`.
///
/// Each inner `Vec` is treated as a set, so duplicates within a list are
/// ignored. The result is the intersection of all of them. An empty `callset`
/// has nothing in common and yields an empty set instead of panicking.
fn find_common_ids(callset: Vec<Vec<String>>) -> HashSet<String> {
    let mut lists = callset.into_iter();
    // Seed the accumulator by moving the first list's strings. No cloning is
    // needed because `callset` is owned.
    let mut common: HashSet<String> = match lists.next() {
        Some(first) => first.into_iter().collect(),
        None => return HashSet::new(),
    };
    for list in lists {
        if common.is_empty() {
            // The intersection can only shrink; nothing left to filter.
            break;
        }
        // Borrow the strings for membership tests rather than cloning them.
        let members: HashSet<&str> = list.iter().map(String::as_str).collect();
        common.retain(|id| members.contains(id.as_str()));
    }
    common
}
jharting
  • 71
  • 3