
In pl.join(), is there any syntactic sugar to "cleverly" cast the dtypes of the join columns, e.g. by picking the higher-granularity dtype of the two, or by simply taking the dtypes from df1? Could this be added as an optional parameter to pl.join()?

e.g. int32 -> int64, datetime[ms] -> datetime[ns]

The goal is to avoid the dreaded: exceptions.ComputeError: datatypes of join keys don't match

user1441053
  • You could cast "manually" using df1's schema, e.g. `on = ["col1", "col2"]; df1.join(on=on, other=df2.with_columns([pl.col(col).cast(df1.schema[col]) for col in on]))` - not sure if there is any syntax sugar for it. – jqurious May 16 '23 at 12:39
  • This has been discussed on the Polars issue tracker before, and the stance is that there will not be automatic upcasting. See for example https://github.com/pola-rs/polars/issues/2815 – jvz May 20 '23 at 12:39

1 Answer


You can write your own join method and then monkey-patch it onto pl.DataFrame.

This only handles floats and ints, but you can build on it. It has ample room for improvement.

import polars as pl

def myjoin(self, 
        other: pl.DataFrame, 
        on: str | pl.Expr | None = None, 
        how: pl.type_aliases.JoinStrategy = 'inner', 
        left_on: str | pl.Expr | None = None, 
        right_on: str | pl.Expr | None = None, 
        suffix: str = '_right'):
    # Normalise the arguments so that only left_on/right_on are used below
    if left_on is None and right_on is None and on is not None:
        left_on=on
        right_on=on
    elif on is None and left_on is not None and right_on is not None:
        pass
        #should check for other consistency (len etc)
    else:
        raise ValueError("inconsistent right_on, left_on, on")
    if isinstance(left_on, str):
        left_on=[left_on]
    if isinstance(right_on, str):
        right_on=[right_on]
    for i, col in enumerate(left_on):
        # The matching key in other may have a different name than in self
        rcol=right_on[i]
        if self.schema[col]!=other.schema[rcol]:
            if self.schema[col] in pl.datatypes.INTEGER_DTYPES and other.schema[rcol] in pl.datatypes.INTEGER_DTYPES:
                # Mismatched integer keys: cast both sides to Int64
                self=self.with_columns(pl.col(col).cast(pl.Int64()))
                other=other.with_columns(pl.col(rcol).cast(pl.Int64()))
            elif self.schema[col] in pl.datatypes.FLOAT_DTYPES and other.schema[rcol] in pl.datatypes.FLOAT_DTYPES:
                # Mismatched float keys: cast both sides to Float64
                self=self.with_columns(pl.col(col).cast(pl.Float64()))
                other=other.with_columns(pl.col(rcol).cast(pl.Float64()))
            else:
                raise ValueError("only floats and ints are upgraded, need to add TEMPORAL and other logic")
    return self.join(other, left_on=left_on, right_on=right_on, suffix=suffix, how=how)
# Monkey-patch the method onto pl.DataFrame
pl.DataFrame.myjoin=myjoin

Then, if you have

df=pl.DataFrame({'a':[1,2,3], 'b':[2,3,4]}).with_columns(a=pl.col('a').cast(pl.Int8()))
df2=pl.DataFrame({'a':[1,2,3], 'c':[3,4,5]}).with_columns(a=pl.col('a').cast(pl.Int16()))

You can do

df.myjoin(df2, on='a')

shape: (3, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 2   ┆ 3   │
│ 2   ┆ 3   ┆ 4   │
│ 3   ┆ 4   ┆ 5   │
└─────┴─────┴─────┘

I only made it check for floats and ints, and it goes straight to the 64-bit variety rather than working out which of the two sides needs casting and casting only that one. It also doesn't cast ints to floats, but you could add that logic. It's probably better to make both self and other lazy before the for loop that casts the join columns. I also didn't attempt the datetime conversions, but adding them should just be tedium at this point.

Dean MacGregor