
Correlation matrix like DataFrame in Polars


I have a Polars DataFrame:

import polars as pl

data = {
    "col1": ["a", "b", "c", "d"],
    "col2": [[-0.06066, 0.072485, 0.548874, 0.158507],
             [-0.536674, 0.10478, 0.926022, -0.083722],
             [-0.21311, -0.030623, 0.300583, 0.261814],
             [-0.308025, 0.006694, 0.176335, 0.533835]],
}

df = pl.DataFrame(data)

I want to calculate the cosine similarity between the col2 vectors for every pair of col1 values.

The desired output should be the following:

┌─────────────────┬──────┬──────┬──────┬──────┐
│ col1_col2       ┆ a    ┆ b    ┆ c    ┆ d    │
│ ---             ┆ ---  ┆ ---  ┆ ---  ┆ ---  │
│ str             ┆ f64  ┆ f64  ┆ f64  ┆ f64  │
╞═════════════════╪══════╪══════╪══════╪══════╡
│ a               ┆ 1.0  ┆ 0.86 ┆ 0.83 ┆ 0.54 │
│ b               ┆ 0.86 ┆ 1.0  ┆ 0.75 ┆ 0.41 │
│ c               ┆ 0.83 ┆ 0.75 ┆ 1.0  ┆ 0.89 │
│ d               ┆ 0.54 ┆ 0.41 ┆ 0.89 ┆ 1.0  │
└─────────────────┴──────┴──────┴──────┴──────┘

Each value is the cosine similarity between the col2 vectors of the corresponding row and column labels.

I tried to use the pivot method:

df.pivot(values="col2", index="col1", columns="col1", aggregate_function=cosine_similarity)

However, I'm getting the following error:

'function' object has no attribute '_pyexpr'

I'm using the following cosine similarity function:

from numpy.linalg import norm

cosine_similarity = lambda a,b: (a @ b.T) / (norm(a)*norm(b))

However, I am open to any other implementation of it.


Solution

  • You can cross join and filter to get the pairs (the equivalent of itertools.combinations_with_replacement(..., r=2)), then use expressions for the similarity calculation:

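    # Flatten each list column into its individual elements
    x = pl.col("col2").flatten()
    y = pl.col("col2_right").flatten()
    
    # A running count over the first column gives every row a unique id to window on
    row = pl.first().cum_count()
    
    # Dot product divided by the product of the two norms, evaluated per row
    cosine_similarity = (
       x.dot(y) / (x.pow(2).sum().sqrt() * y.pow(2).sum().sqrt())
    ).over(row)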
    
    (df.join(df, how = "cross")
       .filter(pl.col("col1") <= pl.col("col1_right"))
       .select(
          col    = "col1",
          other  = "col1_right",
          cosine = cosine_similarity
       )
    )
    
    shape: (10, 3)
    ┌─────┬───────┬──────────┐
    │ col ┆ other ┆ cosine   │
    │ --- ┆ ---   ┆ ---      │
    │ str ┆ str   ┆ f64      │
    ╞═════╪═══════╪══════════╡
    │ a   ┆ a     ┆ 1.0      │
    │ a   ┆ b     ┆ 0.856754 │
    │ a   ┆ c     ┆ 0.827877 │
    │ a   ┆ d     ┆ 0.540282 │
    │ b   ┆ b     ┆ 1.0      │
    │ b   ┆ c     ┆ 0.752199 │
    │ b   ┆ d     ┆ 0.411564 │
    │ c   ┆ c     ┆ 1.0      │
    │ c   ┆ d     ┆ 0.889009 │
    │ d   ┆ d     ┆ 1.0      │
    └─────┴───────┴──────────┘
    

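    You can then .pivot the result if you want the wide layout; drop the filter first if the pivoted table should contain both triangles of the matrix.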