Tags: python, dataframe, lazy-evaluation, python-polars

Why does collecting a LazyFrame before joins in Polars solve my issue with index discrepancies?


Here is a runnable example that demonstrates the issue. The initial LazyFrame contains pairwise distances between points in a plane. Since the A->B distance equals the B->A distance, I keep only unique pairs and then construct a labels LazyFrame that holds only the unique labels. I join the initial LazyFrame with the labels LazyFrame to retrieve the corresponding indices for col_1 and col_2. In the resulting dataframe, the distance from an index value to the same index value (e.g., A-A (0->0), B-B (1->1), C-C (2->2)) should be zero, but it is not when the LazyFrame is not collected before the joins.

import polars as pl

data = {
    "col_1": ["A", "A", "A", "B", "B", "B", "C", "C", "C"],
    "col_2": ["B", "C", "A", "A", "C", "B", "A", "B", "C"],
    "col_3": [1.0, 2.0, 0.0, 1.0, 1.5, 0.0, 2.0, 1.5, 0.0]
}

# Create LazyFrame
pairwise_distances = pl.LazyFrame(data)

# Function to concatenate and sort two elements and return a string with '||'
def concat_and_sort(a, b):
    return "||".join(sorted([a, b]))

# Add pairs column and get unique pairs
pairwise_distances = pairwise_distances.with_columns(
    pl.struct(["col_1", "col_2"]).map_elements(
        lambda x: concat_and_sort(x["col_1"], x["col_2"]),
        return_dtype=pl.String
    ).alias("pairs")
).unique(subset="pairs").select(pl.col("col_1"), pl.col("col_2"), pl.col("col_3"))

# Concatenate col_1 and col_2, and get unique labels with index
labels = (
    pl.concat([
        pairwise_distances.select(pl.col("col_1").alias("label")),
        pairwise_distances.select(pl.col("col_2").alias("label"))
    ])
    .unique(keep="first")
    .with_row_count(name="index")
)

# Collect the labels and pairwise distances LazyFrames to DataFrames
labels_df = labels.collect()
pairwise_distances_df = pairwise_distances.collect()

# Join to get index_1 (without collecting labels)
data_joined_direct = pairwise_distances.join(
    labels,
    left_on="col_1",
    right_on="label",
    how="left"
).rename({"index": "index_1"})

# Join to get index_2 (without collecting labels)
data_joined_direct = data_joined_direct.join(
    labels,
    left_on="col_2",
    right_on="label",
    how="left"
).rename({"index": "index_2"})

# Select relevant columns (without collecting labels)
result_direct = data_joined_direct.select(["index_1", "index_2", "col_3"])

# Collect the result to execute the lazy operations (without collecting labels first)
result_direct_df = result_direct.collect()

# Join to get index_1 (with collecting labels first)
data_joined_collected = pairwise_distances_df.join(
    labels_df,
    left_on="col_1",
    right_on="label",
).rename({"index": "index_1"})

# Join to get index_2 (with collecting labels first)
data_joined_collected = data_joined_collected.join(
    labels_df,
    left_on="col_2",
    right_on="label",
).rename({"index": "index_2"})

# Select relevant columns (with collecting labels first)
result_collected = data_joined_collected.select(["index_1", "index_2", "col_3"])
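
One way to see the discrepancy (this check is not part of the original code) is to look at the rows that should be self-distances, i.e. where index_1 equals index_2:

# Rows where index_1 == index_2 correspond to A-A, B-B, C-C and should all
# have col_3 == 0. This holds for result_collected, but it can fail for
# result_direct_df.
print(result_collected.filter(pl.col("index_1") == pl.col("index_2")))
print(result_direct_df.filter(pl.col("index_1") == pl.col("index_2")))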

Original Post

I'm using Polars to process a dataset where I need to create unique labels from two columns and then perform joins to get indices for those labels. However, I noticed that if I perform the joins directly on the LazyFrame, the indices seem to be incorrect. When I collect the LazyFrame into a DataFrame before performing the joins, the indices are correct.

Here are the relevant code snippets:

  1. Creating the labels LazyFrame:
import polars as pl

# Assume data is a LazyFrame with columns col_1, col_2, col_3
data = pl.scan_csv(
    source=source_filepath,
    separator='\t',
    has_header=False,
)

# Concatenate col_1 and col_2, and get unique labels with index
labels = (
    pl.concat([
        data.select(pl.col("col_1").alias("label")),
        data.select(pl.col("col_2").alias("label"))
    ])
    .unique(keep="first")
    .with_row_count(name="label_index")
)

  2. Joining without collecting (This gives incorrect indices):
# Join to get index_1
data = data.join(
    labels,
    left_on="col_1",
    right_on="label",
).rename({"label_index": "index_1"})

# Join to get index_2
data = data.join(
    labels,
    left_on="col_2",
    right_on="label",
).rename({"label_index": "index_2"})

result = data.select(["index_1", "index_2", "col_3"])
result_df = result.collect()

  3. Joining after collecting labels (This gives correct index_1 and index_2 values):
# Collect labels and data LazyFrame to DataFrame
labels_df = labels.collect()
data_df = data.collect()

# Join to get index_1
data_df = data_df.join(
    labels_df,
    left_on="col_1",
    right_on="label",
    how="left"
).rename({"label_index": "index_1"})

# Join to get index_2
data_df = data_df.join(
    labels_df,
    left_on="col_2",
    right_on="label",
    how="left"
).rename({"label_index": "index_2"})

result_df = data_df.select(["index_1", "index_2", "col_3"])

Why is there this discrepancy between using a LazyFrame directly for joins and collecting it before performing the joins? How can I ensure correct behavior without needing to collect the LazyFrame prematurely?

Any insights into why this happens and how to resolve it would be greatly appreciated!


Solution

  • Answer to the question

    The reason you get different answers is that unique doesn't preserve order by default. Your label indices depend on that order, so each time the labels plan is evaluated the same label can be assigned a different index, and the two joins end up disagreeing. The only thing you need to do to fix this is:

    labels = (
        pl.concat([
            pairwise_distances.select(pl.col("col_1").alias("label")),
            pairwise_distances.select(pl.col("col_2").alias("label"))
        ])
        .sort('label')
         ## Add maintain_order in this unique
        .unique(keep="first", maintain_order=True)
        .with_row_count(name="index")
    )
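
    With the label order pinned down, both joins in the fully lazy path see the same label-to-index mapping, so the self-distance check from the question now passes. As a quick confirmation (not part of the original answer, and assuming the example above is re-run with this labels definition):

    # Every self-pair (index_1 == index_2) should now have a distance of zero,
    # even without collecting labels before the joins.
    assert (
        result_direct.collect()
        .filter(pl.col("index_1") == pl.col("index_2"))
        .get_column("col_3")
        .eq(0.0)
        .all()
    )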
    

    Unsolicited map_elements prevention

    You can expressionize (I'm making it a word, if it isn't already) your pairwise_distances as either

    (
        pairwise_distances
        .filter(pl.col('col_1')<=pl.col('col_2'))
        .unique(['col_1','col_2'])
        .collect()
    )
    

    This one assumes that you have, for example, an A,B row for every B,A row. If you're worried that you might have a B,A but not an A,B in real data, then you can do:

    (
        pairwise_distances
        .select(
            pl.min_horizontal('col_1','col_2').alias('col_1'), 
            pl.max_horizontal('col_1','col_2').alias('col_2'),
            'col_3')
        .unique(['col_1','col_2'])
    )
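
    Putting the two pieces together, a fully lazy version of the pipeline from the question could look like the sketch below (not taken from the original answer; it combines the min/max normalization above with the order-stable labels):

    # Sketch: expression-based dedup plus a deterministic labels table.
    deduped = (
        pairwise_distances
        .select(
            pl.min_horizontal("col_1", "col_2").alias("col_1"),
            pl.max_horizontal("col_1", "col_2").alias("col_2"),
            "col_3",
        )
        .unique(["col_1", "col_2"])
    )

    labels = (
        pl.concat([
            deduped.select(pl.col("col_1").alias("label")),
            deduped.select(pl.col("col_2").alias("label")),
        ])
        .sort("label")
        .unique(keep="first", maintain_order=True)
        .with_row_count(name="index")
    )

    result = (
        deduped
        .join(labels, left_on="col_1", right_on="label", how="left")
        .rename({"index": "index_1"})
        .join(labels, left_on="col_2", right_on="label", how="left")
        .rename({"index": "index_2"})
        .select("index_1", "index_2", "col_3")
        .collect()
    )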