
Polars Join cleverly infer appropriate dtypes of join key?


In pl.join(), is there any syntactic sugar to "cleverly" cast the dtypes of the join columns, e.g. to the higher-granularity option, or to just take the dtypes from df1? Could we add it as an optional parameter to pl.join()?

e.g. int32 -> int64, datetime[ms] -> datetime[ns]

to avoid the dreaded: exceptions.ComputeError: datatypes of join keys don't match


Solution

  • You can write your own join and then monkey-patch it onto pl.DataFrame.

    This only handles floats and ints, but you can build on it. It has ample room for improvement.

    import polars as pl

    def myjoin(self,
            other: pl.DataFrame, 
            on: str | pl.Expr | None = None, 
            how: pl.type_aliases.JoinStrategy = 'inner', 
            left_on: str | pl.Expr | None = None, 
            right_on: str | pl.Expr | None = None, 
            suffix: str = '_right'):
        if left_on is None and right_on is None and on is not None:
            left_on=on
            right_on=on
        elif on is None and left_on is not None and right_on is not None:
            pass
            #should check for other consistency (len etc)
        else:
            raise ValueError("inconsistent right_on, left_on, on")
        if isinstance(left_on, str):
            left_on=[left_on]
        if isinstance(right_on, str):
            right_on=[right_on]
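        # Assumes left_on/right_on are column names (not expressions) of equal length.
        for i, col in enumerate(left_on):
            right_col = right_on[i]
            # Look up the right-hand dtype under its own column name, which may differ from the left one.
            if self.schema[col] != other.schema[right_col]:
                if self.schema[col] in pl.datatypes.INTEGER_DTYPES and other.schema[right_col] in pl.datatypes.INTEGER_DTYPES:
                    # Both keys are integers: upcast both to Int64.
                    self = self.with_columns(pl.col(col).cast(pl.Int64()))
                    other = other.with_columns(pl.col(right_col).cast(pl.Int64()))
                elif self.schema[col] in pl.datatypes.FLOAT_DTYPES and other.schema[right_col] in pl.datatypes.FLOAT_DTYPES:
                    # Both keys are floats: upcast both to Float64.
                    self = self.with_columns(pl.col(col).cast(pl.Float64()))
                    other = other.with_columns(pl.col(right_col).cast(pl.Float64()))
                else:
                    raise ValueError("only floats and ints are upgraded, need to add TEMPORAL and other logic")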
        return self.join(other, left_on=left_on, right_on=right_on, suffix=suffix, how=how)
    pl.DataFrame.myjoin=myjoin
    

    Then, if you have

    df=pl.DataFrame({'a':[1,2,3], 'b':[2,3,4]}).with_columns(a=pl.col('a').cast(pl.Int8()))
    df2=pl.DataFrame({'a':[1,2,3], 'c':[3,4,5]}).with_columns(a=pl.col('a').cast(pl.Int16()))
    

    You can do

    df.myjoin(df2, on='a')
    
    shape: (3, 3)
    ┌─────┬─────┬─────┐
    │ a   ┆ b   ┆ c   │
    │ --- ┆ --- ┆ --- │
    │ i64 ┆ i64 ┆ i64 │
    ╞═════╪═════╪═════╡
    │ 1   ┆ 2   ┆ 3   │
    │ 2   ┆ 3   ┆ 4   │
    │ 3   ┆ 4   ┆ 5   │
    └─────┴─────┴─────┘
    

    I only made it check for floats and ints, and it goes straight to the 64-bit variety rather than working out which of the two columns needs casting and casting only that one. It also doesn't cast ints to floats, but you could add that logic. It's probably better to make both self and other lazy before the for loop that casts the join columns. I didn't attempt the datetime conversions either, but adding them should just be tedious at this point.