Tags: python, pandas, python-polars

Polars - Perform matrix inner product on lazy frames to produce a sparse representation of the Gram matrix


Suppose we have a Polars DataFrame like:

import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]}).lazy()

shape: (3, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 3   │
│ 2   ┆ 4   │
│ 3   ┆ 5   │
└─────┴─────┘

I would like to compute X^T X on this frame while keeping the result in a sparse (long) matrix format that works well with Arrow* - in pandas I would do something like:

pdf = df.collect().to_pandas()
numbers = pdf[["a", "b"]]
(numbers.T @ numbers).melt(ignore_index=False)

  variable  value
a        a     14
b        a     26
a        b     26
b        b     50

I did something like this in Polars:

df.select(
    (pl.col("a") * pl.col("a")).sum().alias("aa"),
    (pl.col("a") * pl.col("b")).sum().alias("ab"),
    (pl.col("b") * pl.col("a")).sum().alias("ba"),
    (pl.col("b") * pl.col("b")).sum().alias("bb"),
).unpivot().collect()

shape: (4, 2)
┌──────────┬───────┐
│ variable ┆ value │
│ ---      ┆ ---   │
│ str      ┆ i64   │
╞══════════╪═══════╡
│ aa       ┆ 14    │
│ ab       ┆ 26    │
│ ba       ┆ 26    │
│ bb       ┆ 50    │
└──────────┴───────┘

This is almost there, but not quite. It's a hack to get around the fact that I can't store lists as column names (otherwise I could unnest them into two separate columns representing the x and y axes of the matrix). Is there a way to get the same format as shown in the pandas example?

*Arrow is a columnar data format, which means it performs well when scaling across rows but not across columns. That's why I think the sparse matrix representation is the better choice if I want to chain the Gram matrix results with pl.LazyFrames later down the graph. I could be wrong though!


Solution

  • Polars doesn't have matrix multiplication, but we can tweak your algorithm slightly to accomplish what we need:

    • use the built-in dot expression
    • calculate each inner product only once, since <a, b> = <b, a>. We'll use Python's combinations_with_replacement iterator from itertools to accomplish this (see the short demo after this list).
    • automatically generate the list of expressions that will run in parallel
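
    For example, here's a quick look at what combinations_with_replacement yields for three columns - each unordered pair shows up exactly once, so no dot product gets computed twice (a small illustrative snippet, using the same column names as the expanded data below):

    from itertools import combinations_with_replacement

    # Each unordered pair appears exactly once, including the diagonal pairs
    print(list(combinations_with_replacement(["a", "b", "c"], 2)))
    # [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]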

    Let's expand your data a bit:

    from itertools import combinations_with_replacement
    import polars as pl
    
    df = pl.DataFrame(
        {"a": [1, 2, 3, 4, 5], "b": [3, 4, 5, 6, 7], "c": [5, 6, 7, 8, 9]}
    ).lazy()
    df.collect()
    
    shape: (5, 3)
    ┌─────┬─────┬─────┐
    │ a   ┆ b   ┆ c   │
    │ --- ┆ --- ┆ --- │
    │ i64 ┆ i64 ┆ i64 │
    ╞═════╪═════╪═════╡
    │ 1   ┆ 3   ┆ 5   │
    │ 2   ┆ 4   ┆ 6   │
    │ 3   ┆ 5   ┆ 7   │
    │ 4   ┆ 6   ┆ 8   │
    │ 5   ┆ 7   ┆ 9   │
    └─────┴─────┴─────┘
    

    The algorithm would be as follows:

    expr_list = [
        pl.col(col1).dot(pl.col(col2)).alias(col1 + "|" + col2)
        for col1, col2 in combinations_with_replacement(df.columns, 2)
    ]
    
    dot_prods = (
        df
        .select(expr_list)
        .unpivot()
        .with_columns(
            pl.col('variable').str.split_exact('|', 1)
        )
        .unnest('variable')
    )
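
    # At this point dot_prods holds only the upper triangle of the Gram
    # matrix - one row per unordered pair of columns (a|a, a|b, a|c, b|b,
    # b|c, c|c). The concat below mirrors the off-diagonal rows to fill
    # in the lower triangle.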
    
    result = (
        pl.concat([
            dot_prods,
            dot_prods
            .filter(pl.col('field_0') != pl.col('field_1'))
            .select('field_1', 'field_0', 'value')
            .rename({'field_0':'field_1', 'field_1': 'field_0'})
            ],
        )
        .sort('field_0', 'field_1')
    )
    result.collect()
    
    shape: (9, 3)
    ┌─────────┬─────────┬───────┐
    │ field_0 ┆ field_1 ┆ value │
    │ ---     ┆ ---     ┆ ---   │
    │ str     ┆ str     ┆ i64   │
    ╞═════════╪═════════╪═══════╡
    │ a       ┆ a       ┆ 55    │
    │ a       ┆ b       ┆ 85    │
    │ a       ┆ c       ┆ 115   │
    │ b       ┆ a       ┆ 85    │
    │ b       ┆ b       ┆ 135   │
    │ b       ┆ c       ┆ 185   │
    │ c       ┆ a       ┆ 115   │
    │ c       ┆ b       ┆ 185   │
    │ c       ┆ c       ┆ 255   │
    └─────────┴─────────┴───────┘
    

    Couple of notes:

    • I'm assuming that a pipe would be an appropriate delimiter for your column names.
    • The use of Python bytecode and an iterator will not significantly impair performance; they are only used to generate the list of expressions, not to run any calculations.
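
    If you want to keep chaining the result with other pl.LazyFrame operations (as mentioned in the question), the same logic can be wrapped in a small helper so everything stays lazy until .collect(). The sketch below is just one way to package the code above - the name gram_long is mine, and it assumes a recent Polars version (the answer already relies on unpivot; LazyFrame.collect_schema() is used here to read column names without triggering computation):

    from itertools import combinations_with_replacement  # already imported above
    import polars as pl

    def gram_long(lf: pl.LazyFrame) -> pl.LazyFrame:
        # Read the column names from the lazy schema without materialising data
        cols = lf.collect_schema().names()
        # Upper triangle of the Gram matrix: each unordered pair computed once
        upper = (
            lf.select(
                [
                    pl.col(c1).dot(pl.col(c2)).alias(c1 + "|" + c2)
                    for c1, c2 in combinations_with_replacement(cols, 2)
                ]
            )
            .unpivot()
            .with_columns(pl.col("variable").str.split_exact("|", 1))
            .unnest("variable")
        )
        # Mirror the off-diagonal rows to recover the full symmetric matrix
        mirrored = (
            upper.filter(pl.col("field_0") != pl.col("field_1"))
            .select("field_1", "field_0", "value")
            .rename({"field_0": "field_1", "field_1": "field_0"})
        )
        return pl.concat([upper, mirrored]).sort("field_0", "field_1")

    # Still a LazyFrame, so it can be joined or filtered further before collecting
    gram_long(df).collect()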