
Say I have this:

import polars

df = polars.DataFrame(dict(
  j=[1,2,3],
  k=[4,5,6],
  l=[7,8,9],
  ))

shape: (3, 3)
┌─────┬─────┬─────┐
│ j   ┆ k   ┆ l   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 4   ┆ 7   │
│ 2   ┆ 5   ┆ 8   │
│ 3   ┆ 6   ┆ 9   │
└─────┴─────┴─────┘

I can filter for a particular row doing it one column at a time, i.e.:

df = df.filter(
  (polars.col('j') == 2) &
  (polars.col('k') == 5) &
  (polars.col('l') == 8)
  )

shape: (1, 3)
┌─────┬─────┬─────┐
│ j   ┆ k   ┆ l   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 2   ┆ 5   ┆ 8   │
└─────┴─────┴─────┘

I'd like to compare against a list instead, though (so I can avoid listing each column, and to accommodate variable-column DataFrames), e.g. something like:

df = df.filter(
    polars.concat_list(polars.all()) == [2, 5, 8]
    )

...
exceptions.ArrowErrorException: NotYetImplemented("Casting from Int64 to LargeList(Field { name: \"item\", data_type: Int64, is_nullable: true, metadata: {} }) not supported")

Any ideas why the above is throwing the exception?

I can build the expression manually:

import functools

df = df.filter(
  functools.reduce(
    lambda a, e: a & e,
    (polars.col(c) == v for c, v in zip(df.columns, [2, 5, 8])),
    )
  )

but I was hoping there's a way to compare lists directly - e.g. as if I had this DataFrame originally:

df = polars.DataFrame(dict(j=[
  [1,4,7],
  [2,5,8],
  [3,6,9],
  ]))

shape: (3, 1)
┌───────────┐
│ j         │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [1, 4, 7] │
│ [2, 5, 8] │
│ [3, 6, 9] │
└───────────┘

and wanted to find the row which matches [2, 5, 8]. Any hints?

levant pied

2 Answers


You can pass multiple conditions to .all() instead of using functools.reduce.

For a list column, you can compare the values at each index with .arr.get():

import polars as pl

row = [2, 5, 8]

df.filter(
   pl.all(
      pl.col("j").arr.get(n) == value  # .arr is named .list in newer polars versions
      for n, value in enumerate(row)
   )
)
shape: (1, 1)
┌───────────┐
│ j         │
│ ---       │
│ list[i64] │
╞═══════════╡
│ [2, 5, 8] │
└───────────┘

I'm not sure why this doesn't work:

>>> df.filter(pl.col("j") == pl.lit([[2, 5, 8]]))
shape: (0, 1)
┌───────────┐
│ j         │
│ ---       │
│ list[i64] │
╞═══════════╡
└───────────┘

For regular columns, you could modify your example:

df.filter(
   pl.all(
      pl.col(col) == value
      for col, value in zip(df.columns, [2, 5, 8])
   )
)
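
Since this pairs df.columns with the target values positionally, it also covers the variable-column case from the question. A minimal sketch of wrapping it in a helper (the name match_row is made up here for illustration; newer polars versions spell the horizontal AND as pl.all_horizontal):

import polars as pl

def match_row(df: pl.DataFrame, values) -> pl.DataFrame:
    # keep rows whose columns equal `values`, paired positionally
    return df.filter(
        pl.all(
            pl.col(col) == value
            for col, value in zip(df.columns, values)
        )
    )

match_row(df, [2, 5, 8])  # shape: (1, 3)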
jqurious

Here's a lightweight approach. You can achieve the desired behaviour by filtering the DataFrame with a list comprehension that compares each row to the target row. To iterate over the rows we can leverage the df.rows() method, which returns a list of tuples, where each tuple represents a single row of the DataFrame; accordingly, we specify the target row as a tuple as well. I enlarged your example DataFrame to 3 million rows to make sure the solution I recommend is sufficiently fast.

tar_row = (2,5,8)
df = df.filter([row == tar_row for row in df.rows()])

Full code:

import polars

factor = 1_000_000
# instantiate a polars DataFrame with 3,000,000 rows
df = polars.DataFrame(dict(
  j=[1,2,3]*factor,
  k=[4,5,6]*factor,
  l=[7,8,9]*factor,
  ))
print(df.shape)  # (3000000, 3)

# filter df based on target values
tar_row = (2,5,8)

df = df.filter([row == tar_row for row in df.rows()])
print(df.shape)  # (1000000, 3)

I ran this code on Google Colab and it took about 1 second. For 30 million rows it took 5 seconds, and for 100 million rows 40 seconds. So while this might not be the fastest approach on earth, it's nevertheless very fast given how lightweight it is. If you're working with a very big DataFrame, I would recommend constructing the filter "the traditional way", as a fully vectorized approach is always faster.
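
For reference, "the traditional way" here just means the expression-based filter from the question and the first answer; a minimal sketch using the same df and tar_row:

import functools
import polars

# vectorized: AND one equality expression per column, evaluated inside polars
df = df.filter(
    functools.reduce(
        lambda a, e: a & e,
        (polars.col(c) == v for c, v in zip(df.columns, tar_row)),
    )
)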

Simon David