
Correlation matrix like DataFrame in Polars


I have a Polars DataFrame:

import polars as pl

data = {
    "col1": ["a", "b", "c", "d"],
    "col2": [[-0.06066, 0.072485, 0.548874, 0.158507],
             [-0.536674, 0.10478, 0.926022, -0.083722],
             [-0.21311, -0.030623, 0.300583, 0.261814],
             [-0.308025, 0.006694, 0.176335, 0.533835]],
}

df = pl.DataFrame(data)

I want to calculate the cosine similarity of the col2 lists for every combination of values in col1.

The desired output should be the following:

┌─────────────────┬──────┬──────┬──────┬──────┐
│ col1_col2       ┆ a    ┆ b    ┆ c    ┆ d    │
│ ---             ┆ ---  ┆ ---  ┆ ---  ┆ ---  │
│ str             ┆ f64  ┆ f64  ┆ f64  ┆ f64  │
╞═════════════════╪══════╪══════╪══════╪══════╡
│ a               ┆ 1.0  ┆ 0.86 ┆ 0.83 ┆ 0.54 │
│ b               ┆ 0.86 ┆ 1.0  ┆ 0.75 ┆ 0.41 │
│ c               ┆ 0.83 ┆ 0.75 ┆ 1.0  ┆ 0.89 │
│ d               ┆ 0.54 ┆ 0.41 ┆ 0.89 ┆ 1.0  │
└─────────────────┴──────┴──────┴──────┴──────┘

Each value is the cosine similarity between the col2 lists of the corresponding row and column labels.
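For two equal-length vectors x and y, cosine similarity is their dot product divided by the product of their Euclidean norms:

```latex
\cos(\mathbf{x}, \mathbf{y}) = \frac{\sum_{i} x_i y_i}{\sqrt{\sum_{i} x_i^2}\,\sqrt{\sum_{i} y_i^2}}
```

So every diagonal entry of the matrix is 1.0, and the matrix is symmetric.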

I'm using the following cosine similarity function:

from numpy.linalg import norm

cosine_similarity = lambda a,b: (a @ b.T) / (norm(a)*norm(b))

I tried to use it with the pivot method:

df.pivot(on="col1", values="col2", index="col1", aggregate_function=cosine_similarity)

However, I get the following error:

AttributeError: 'function' object has no attribute '_pyexpr'

Solution

  • Update: Polars 1.8.0 added native list arithmetic, which makes a much more efficient cosine similarity expression possible.


    Combinations

    We can add a row index and use .join_where() to generate the row "combinations".

    df = df.with_row_index().lazy()
    
    df.join_where(df, pl.col.index <= pl.col.index_right).collect() 
    
    shape: (10, 6)
    ┌───────┬──────┬─────────────────────────────────┬─────────────┬────────────┬─────────────────────────────────┐
    │ index ┆ col1 ┆ col2                            ┆ index_right ┆ col1_right ┆ col2_right                      │
    │ ---   ┆ ---  ┆ ---                             ┆ ---         ┆ ---        ┆ ---                             │
    │ u32   ┆ str  ┆ list[f64]                       ┆ u32         ┆ str        ┆ list[f64]                       │
    ╞═══════╪══════╪═════════════════════════════════╪═════════════╪════════════╪═════════════════════════════════╡
    │ 0     ┆ a    ┆ [-0.06066, 0.072485, … 0.15850… ┆ 0           ┆ a          ┆ [-0.06066, 0.072485, … 0.15850… │
    │ 0     ┆ a    ┆ [-0.06066, 0.072485, … 0.15850… ┆ 1           ┆ b          ┆ [-0.536674, 0.10478, … -0.0837… │
    │ 0     ┆ a    ┆ [-0.06066, 0.072485, … 0.15850… ┆ 2           ┆ c          ┆ [-0.21311, -0.030623, … 0.2618… │
    │ 0     ┆ a    ┆ [-0.06066, 0.072485, … 0.15850… ┆ 3           ┆ d          ┆ [-0.308025, 0.006694, … 0.5338… │
    │ 1     ┆ b    ┆ [-0.536674, 0.10478, … -0.0837… ┆ 1           ┆ b          ┆ [-0.536674, 0.10478, … -0.0837… │
    │ 1     ┆ b    ┆ [-0.536674, 0.10478, … -0.0837… ┆ 2           ┆ c          ┆ [-0.21311, -0.030623, … 0.2618… │
    │ 1     ┆ b    ┆ [-0.536674, 0.10478, … -0.0837… ┆ 3           ┆ d          ┆ [-0.308025, 0.006694, … 0.5338… │
    │ 2     ┆ c    ┆ [-0.21311, -0.030623, … 0.2618… ┆ 2           ┆ c          ┆ [-0.21311, -0.030623, … 0.2618… │
    │ 2     ┆ c    ┆ [-0.21311, -0.030623, … 0.2618… ┆ 3           ┆ d          ┆ [-0.308025, 0.006694, … 0.5338… │
    │ 3     ┆ d    ┆ [-0.308025, 0.006694, … 0.5338… ┆ 3           ┆ d          ┆ [-0.308025, 0.006694, … 0.5338… │
    └───────┴──────┴─────────────────────────────────┴─────────────┴────────────┴─────────────────────────────────┘
    

    Cosine Similarity

    You can write the formula with expressions, e.g. list arithmetic, .list.sum(), and Expr.sqrt().

    cosine_similarity = lambda x, y: (
        (x * y).list.sum() / (
            (x * x).list.sum().sqrt() * (y * y).list.sum().sqrt()
        )
    )
    
    out = (
       df.join_where(df, pl.col.index <= pl.col.index_right)
         .select(
            col = "col1",
            other = "col1_right",
            cosine = cosine_similarity(
               x = pl.col.col2,
               y = pl.col.col2_right
            )
         )
    )
    
    # out.collect()
    shape: (10, 3)
    ┌─────┬───────┬──────────┐
    │ col ┆ other ┆ cosine   │
    │ --- ┆ ---   ┆ ---      │
    │ str ┆ str   ┆ f64      │
    ╞═════╪═══════╪══════════╡
    │ a   ┆ a     ┆ 1.0      │
    │ a   ┆ b     ┆ 0.856754 │
    │ a   ┆ c     ┆ 0.827877 │
    │ a   ┆ d     ┆ 0.540282 │
    │ b   ┆ b     ┆ 1.0      │
    │ b   ┆ c     ┆ 0.752199 │
    │ b   ┆ d     ┆ 0.411564 │
    │ c   ┆ c     ┆ 1.0      │
    │ c   ┆ d     ┆ 0.889009 │
    │ d   ┆ d     ┆ 1.0      │
    └─────┴───────┴──────────┘
    

    Pivot

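    Because the join produced only one ordering of each pair, you can vertically concat the reversed pairings (skipping the diagonal, which would otherwise appear twice) and then .pivot() into the matrix shape.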

    pl.concat(
       [
          out, 
          out.filter(pl.col.col != pl.col.other).select(col="other", other="col", cosine="cosine")
       ]
    ).collect().pivot("other", index="col")
    
    shape: (4, 5)
    ┌─────┬──────────┬──────────┬──────────┬──────────┐
    │ col ┆ a        ┆ b        ┆ c        ┆ d        │
    │ --- ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
    │ str ┆ f64      ┆ f64      ┆ f64      ┆ f64      │
    ╞═════╪══════════╪══════════╪══════════╪══════════╡
    │ a   ┆ 1.0      ┆ 0.856754 ┆ 0.827877 ┆ 0.540282 │
    │ b   ┆ 0.856754 ┆ 1.0      ┆ 0.752199 ┆ 0.411564 │
    │ c   ┆ 0.827877 ┆ 0.752199 ┆ 1.0      ┆ 0.889009 │
    │ d   ┆ 0.540282 ┆ 0.411564 ┆ 0.889009 ┆ 1.0      │
    └─────┴──────────┴──────────┴──────────┴──────────┘