python, python-polars

Find intersection of columns from different polars dataframes


I have a variable number of pl.DataFrames which share some columns (e.g. symbol and date). Each pl.DataFrame has a number of additional columns, which are not important for the actual task.

The symbol columns have exactly the same content (every distinct str value exists in every DataFrame). The date columns differ, however: not every pl.DataFrame contains the same dates.

The actual task is to find the common dates per grouping (i.e. symbol) and filter each pl.DataFrame accordingly.

Here are three example pl.DataFrames:

import polars as pl

df1 = pl.DataFrame(
    {
        "symbol": ["AAPL"] * 4 + ["GOOGL"] * 3,
        "date": [
            "2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04",
            "2023-01-02", "2023-01-03", "2023-01-04",
        ],
        "some_other_col": range(7),
    }
)

df2 = pl.DataFrame(
    {
        "symbol": ["AAPL"] * 3 + ["GOOGL"] * 5,
        "date": [
            "2023-01-02", "2023-01-03", "2023-01-04",
            "2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04", "2023-01-05",
        ],
        "another_col": range(8),
    }
)

df3 = pl.DataFrame(
    {
        "symbol": ["AAPL"] * 4 + ["GOOGL"] * 2,
        "date": [
            "2023-01-02", "2023-01-03", "2023-01-04", "2023-01-05",
            "2023-01-03", "2023-01-04",
        ],
        "some_col": range(6),
    }
)
DataFrame 1:
shape: (7, 3)
┌────────┬────────────┬────────────────┐
│ symbol ┆ date       ┆ some_other_col │
│ ---    ┆ ---        ┆ ---            │
│ str    ┆ str        ┆ i64            │
╞════════╪════════════╪════════════════╡
│ AAPL   ┆ 2023-01-01 ┆ 0              │
│ AAPL   ┆ 2023-01-02 ┆ 1              │
│ AAPL   ┆ 2023-01-03 ┆ 2              │
│ AAPL   ┆ 2023-01-04 ┆ 3              │
│ GOOGL  ┆ 2023-01-02 ┆ 4              │
│ GOOGL  ┆ 2023-01-03 ┆ 5              │
│ GOOGL  ┆ 2023-01-04 ┆ 6              │
└────────┴────────────┴────────────────┘

DataFrame 2:
shape: (8, 3)
┌────────┬────────────┬─────────────┐
│ symbol ┆ date       ┆ another_col │
│ ---    ┆ ---        ┆ ---         │
│ str    ┆ str        ┆ i64         │
╞════════╪════════════╪═════════════╡
│ AAPL   ┆ 2023-01-02 ┆ 0           │
│ AAPL   ┆ 2023-01-03 ┆ 1           │
│ AAPL   ┆ 2023-01-04 ┆ 2           │
│ GOOGL  ┆ 2023-01-01 ┆ 3           │
│ GOOGL  ┆ 2023-01-02 ┆ 4           │
│ GOOGL  ┆ 2023-01-03 ┆ 5           │
│ GOOGL  ┆ 2023-01-04 ┆ 6           │
│ GOOGL  ┆ 2023-01-05 ┆ 7           │
└────────┴────────────┴─────────────┘

DataFrame 3:
shape: (6, 3)
┌────────┬────────────┬──────────┐
│ symbol ┆ date       ┆ some_col │
│ ---    ┆ ---        ┆ ---      │
│ str    ┆ str        ┆ i64      │
╞════════╪════════════╪══════════╡
│ AAPL   ┆ 2023-01-02 ┆ 0        │
│ AAPL   ┆ 2023-01-03 ┆ 1        │
│ AAPL   ┆ 2023-01-04 ┆ 2        │
│ AAPL   ┆ 2023-01-05 ┆ 3        │
│ GOOGL  ┆ 2023-01-03 ┆ 4        │
│ GOOGL  ┆ 2023-01-04 ┆ 5        │
└────────┴────────────┴──────────┘

The first step is to find, for every symbol, the dates that appear in all DataFrames:
AAPL: ["2023-01-02", "2023-01-03", "2023-01-04"]
GOOGL: ["2023-01-03", "2023-01-04"]

Each pl.DataFrame then needs to be filtered down to those dates. The expected result is:

DataFrame 1 filtered:

shape: (5, 3)
┌────────┬────────────┬────────────────┐
│ symbol ┆ date       ┆ some_other_col │
│ ---    ┆ ---        ┆ ---            │
│ str    ┆ str        ┆ i64            │
╞════════╪════════════╪════════════════╡
│ AAPL   ┆ 2023-01-02 ┆ 1              │
│ AAPL   ┆ 2023-01-03 ┆ 2              │
│ AAPL   ┆ 2023-01-04 ┆ 3              │
│ GOOGL  ┆ 2023-01-03 ┆ 5              │
│ GOOGL  ┆ 2023-01-04 ┆ 6              │
└────────┴────────────┴────────────────┘

DataFrame 2 filtered:
shape: (5, 3)
┌────────┬────────────┬─────────────┐
│ symbol ┆ date       ┆ another_col │
│ ---    ┆ ---        ┆ ---         │
│ str    ┆ str        ┆ i64         │
╞════════╪════════════╪═════════════╡
│ AAPL   ┆ 2023-01-02 ┆ 0           │
│ AAPL   ┆ 2023-01-03 ┆ 1           │
│ AAPL   ┆ 2023-01-04 ┆ 2           │
│ GOOGL  ┆ 2023-01-03 ┆ 5           │
│ GOOGL  ┆ 2023-01-04 ┆ 6           │
└────────┴────────────┴─────────────┘

DataFrame 3 filtered:
shape: (5, 3)
┌────────┬────────────┬──────────┐
│ symbol ┆ date       ┆ some_col │
│ ---    ┆ ---        ┆ ---      │
│ str    ┆ str        ┆ i64      │
╞════════╪════════════╪══════════╡
│ AAPL   ┆ 2023-01-02 ┆ 0        │
│ AAPL   ┆ 2023-01-03 ┆ 1        │
│ AAPL   ┆ 2023-01-04 ┆ 2        │
│ GOOGL  ┆ 2023-01-03 ┆ 4        │
│ GOOGL  ┆ 2023-01-04 ┆ 5        │
└────────┴────────────┴──────────┘

Solution

  • You can use pl.DataFrame.join() with the how="semi" parameter:

    semi: Returns rows from the left table that have a match in the right table.

    on = ["symbol", "date"]
    print(df1.join(df2, on=on, how="semi").join(df3, on=on, how="semi"))
    print(df2.join(df1, on=on, how="semi").join(df3, on=on, how="semi"))
    print(df3.join(df1, on=on, how="semi").join(df2, on=on, how="semi"))
    
    shape: (5, 3)
    ┌────────┬────────────┬────────────────┐
    │ symbol ┆ date       ┆ some_other_col │
    │ ---    ┆ ---        ┆ ---            │
    │ str    ┆ str        ┆ i64            │
    ╞════════╪════════════╪════════════════╡
    │ AAPL   ┆ 2023-01-02 ┆ 1              │
    │ AAPL   ┆ 2023-01-03 ┆ 2              │
    │ AAPL   ┆ 2023-01-04 ┆ 3              │
    │ GOOGL  ┆ 2023-01-03 ┆ 5              │
    │ GOOGL  ┆ 2023-01-04 ┆ 6              │
    └────────┴────────────┴────────────────┘
    shape: (5, 3)
    ┌────────┬────────────┬─────────────┐
    │ symbol ┆ date       ┆ another_col │
    │ ---    ┆ ---        ┆ ---         │
    │ str    ┆ str        ┆ i64         │
    ╞════════╪════════════╪═════════════╡
    │ AAPL   ┆ 2023-01-02 ┆ 0           │
    │ AAPL   ┆ 2023-01-03 ┆ 1           │
    │ AAPL   ┆ 2023-01-04 ┆ 2           │
    │ GOOGL  ┆ 2023-01-03 ┆ 5           │
    │ GOOGL  ┆ 2023-01-04 ┆ 6           │
    └────────┴────────────┴─────────────┘
    shape: (5, 3)
    ┌────────┬────────────┬──────────┐
    │ symbol ┆ date       ┆ some_col │
    │ ---    ┆ ---        ┆ ---      │
    │ str    ┆ str        ┆ i64      │
    ╞════════╪════════════╪══════════╡
    │ AAPL   ┆ 2023-01-02 ┆ 0        │
    │ AAPL   ┆ 2023-01-03 ┆ 1        │
    │ AAPL   ┆ 2023-01-04 ┆ 2        │
    │ GOOGL  ┆ 2023-01-03 ┆ 4        │
    │ GOOGL  ┆ 2023-01-04 ┆ 5        │
    └────────┴────────────┴──────────┘
    

    Or, to generalize this to any number of DataFrames:

    on = ["symbol","date"]
    dfs = [df1, df2, df3]
    
    # filter first dataframe on all others
    for df in dfs[1:]:
        dfs[0] = dfs[0].join(df, on=on, how="semi")
    
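    # then filter all others on the (now fully filtered) first one;
    # start=1 keeps i aligned with dfs[1:] so dfs[0] is not overwritten
    for i, df in enumerate(dfs[1:], start=1):
        dfs[i] = df.join(dfs[0], on=on, how="semi")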
    
    for df in dfs:
        print(df)