I have a Polars dataframe:
# Sample frame: one label per row ("col1") and one 4-element float vector per label ("col2").
data = {
    "col1": ["a", "b", "c", "d"],
    "col2": [
        [-0.06066, 0.072485, 0.548874, 0.158507],
        [-0.536674, 0.10478, 0.926022, -0.083722],
        [-0.21311, -0.030623, 0.300583, 0.261814],
        [-0.308025, 0.006694, 0.176335, 0.533835],
    ],
}
df = pl.DataFrame(data)
I want to calculate the cosine similarity between the `col2` vectors for every pair of values in column `col1`.
The desired output should be the following:
┌─────────────────┬──────┬──────┬──────┬──────┐
│ col1_col2 ┆ a ┆ b ┆ c ┆ d │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞═════════════════╪══════╪══════╪══════╪══════╡
│ a ┆ 1.0 ┆ 0.86 ┆ 0.83 ┆ 0.54 │
│ b ┆ 0.86 ┆ 1.0 ┆ 0.75 ┆ 0.41 │
│ c ┆ 0.83 ┆ 0.75 ┆ 1.0 ┆ 0.89 │
│ d ┆ 0.54 ┆ 0.41 ┆ 0.89 ┆ 1.0 │
└─────────────────┴──────┴──────┴──────┴──────┘
Each value is the cosine similarity between the `col2` vectors of the corresponding row and column labels.
I'm using the following cosine similarity function:
from numpy.linalg import norm
def cosine_similarity(a, b):
    """Return the cosine similarity of ``a`` and ``b``.

    The result is ``a @ b.T`` divided by the product of ``norm(a)`` and
    ``norm(b)``. Both arguments must support ``@`` and ``.T`` (e.g. NumPy
    arrays).
    """
    dot_product = a @ b.T
    magnitude = norm(a) * norm(b)
    return dot_product / magnitude
I tried to use it with the `pivot` method:
# NOTE: this is the failing attempt. `aggregate_function` is given a plain Python
# callable here, and the call raises the AttributeError shown below (Polars looks
# for `_pyexpr` on the object, which a function does not have).
df.pivot(on="col1", values="col2", index="col1", aggregate_function=cosine_similarity)
However, I'm getting the following error:
AttributeError: 'function' object has no attribute '_pyexpr'
Update: Polars 1.8.0 added native list arithmetic allowing us to write a much more efficient cosine similarity expression.
We can add a row index and use `.join_where()` to generate the row "combinations".
# Attach a row index and switch to lazy mode; `df` is reused by the later steps.
df = df.with_row_index().lazy()

# Self-join keeping only rows whose left index is <= the right index, so each
# unordered pair appears once and every row is also paired with itself.
keep_upper_triangle = pl.col("index") <= pl.col("index_right")
df.join_where(df, keep_upper_triangle).collect()
shape: (10, 6)
┌───────┬──────┬─────────────────────────────────┬─────────────┬────────────┬─────────────────────────────────┐
│ index ┆ col1 ┆ col2 ┆ index_right ┆ col1_right ┆ col2_right │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u32 ┆ str ┆ list[f64] ┆ u32 ┆ str ┆ list[f64] │
╞═══════╪══════╪═════════════════════════════════╪═════════════╪════════════╪═════════════════════════════════╡
│ 0 ┆ a ┆ [-0.06066, 0.072485, … 0.15850… ┆ 0 ┆ a ┆ [-0.06066, 0.072485, … 0.15850… │
│ 0 ┆ a ┆ [-0.06066, 0.072485, … 0.15850… ┆ 1 ┆ b ┆ [-0.536674, 0.10478, … -0.0837… │
│ 0 ┆ a ┆ [-0.06066, 0.072485, … 0.15850… ┆ 2 ┆ c ┆ [-0.21311, -0.030623, … 0.2618… │
│ 0 ┆ a ┆ [-0.06066, 0.072485, … 0.15850… ┆ 3 ┆ d ┆ [-0.308025, 0.006694, … 0.5338… │
│ 1 ┆ b ┆ [-0.536674, 0.10478, … -0.0837… ┆ 1 ┆ b ┆ [-0.536674, 0.10478, … -0.0837… │
│ 1 ┆ b ┆ [-0.536674, 0.10478, … -0.0837… ┆ 2 ┆ c ┆ [-0.21311, -0.030623, … 0.2618… │
│ 1 ┆ b ┆ [-0.536674, 0.10478, … -0.0837… ┆ 3 ┆ d ┆ [-0.308025, 0.006694, … 0.5338… │
│ 2 ┆ c ┆ [-0.21311, -0.030623, … 0.2618… ┆ 2 ┆ c ┆ [-0.21311, -0.030623, … 0.2618… │
│ 2 ┆ c ┆ [-0.21311, -0.030623, … 0.2618… ┆ 3 ┆ d ┆ [-0.308025, 0.006694, … 0.5338… │
│ 3 ┆ d ┆ [-0.308025, 0.006694, … 0.5338… ┆ 3 ┆ d ┆ [-0.308025, 0.006694, … 0.5338… │
└───────┴──────┴─────────────────────────────────┴─────────────┴────────────┴─────────────────────────────────┘
You can write the formula using expressions, e.g. list arithmetic, `.list.sum()` and `Expr.sqrt()`.
def cosine_similarity(x, y):
    """Build an expression for the cosine similarity of two list columns.

    ``x`` and ``y`` are expressions over list-typed columns (the caller passes
    ``col2`` and ``col2_right``). The element-wise products are summed per row
    with ``.list.sum()`` and the norms are taken with ``.sqrt()``.
    """
    dot_product = (x * y).list.sum()
    x_norm = (x * x).list.sum().sqrt()
    y_norm = (y * y).list.sum().sqrt()
    return dot_product / (x_norm * y_norm)
# Pair each row with itself and with every later row, then compute one cosine
# value per pair. `out` stays lazy; call `.collect()` to materialise it.
paired = df.join_where(df, pl.col("index") <= pl.col("index_right"))
out = paired.select(
    col="col1",
    other="col1_right",
    cosine=cosine_similarity(x=pl.col("col2"), y=pl.col("col2_right")),
)
# out.collect()
shape: (10, 3)
┌─────┬───────┬──────────┐
│ col ┆ other ┆ cosine │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ f64 │
╞═════╪═══════╪══════════╡
│ a ┆ a ┆ 1.0 │
│ a ┆ b ┆ 0.856754 │
│ a ┆ c ┆ 0.827877 │
│ a ┆ d ┆ 0.540282 │
│ b ┆ b ┆ 1.0 │
│ b ┆ c ┆ 0.752199 │
│ b ┆ d ┆ 0.411564 │
│ c ┆ c ┆ 1.0 │
│ c ┆ d ┆ 0.889009 │
│ d ┆ d ┆ 1.0 │
└─────┴───────┴──────────┘
You can vertically concat/stack the reverse pairings and then `.pivot()` for the matrix shape.
# Mirror the off-diagonal pairs (swap `col` and `other`) so the matrix is
# filled on both sides; the diagonal rows are left out to avoid duplicates.
mirrored = out.filter(pl.col("col") != pl.col("other")).select(
    col="other", other="col", cosine="cosine"
)

# Stack both halves, materialise, and reshape into the square matrix.
pl.concat([out, mirrored]).collect().pivot("other", index="col")
shape: (4, 5)
┌─────┬──────────┬──────────┬──────────┬──────────┐
│ col ┆ a ┆ b ┆ c ┆ d │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞═════╪══════════╪══════════╪══════════╪══════════╡
│ a ┆ 1.0 ┆ 0.856754 ┆ 0.827877 ┆ 0.540282 │
│ b ┆ 0.856754 ┆ 1.0 ┆ 0.752199 ┆ 0.411564 │
│ c ┆ 0.827877 ┆ 0.752199 ┆ 1.0 ┆ 0.889009 │
│ d ┆ 0.540282 ┆ 0.411564 ┆ 0.889009 ┆ 1.0 │
└─────┴──────────┴──────────┴──────────┴──────────┘