
Polars list of values based on intersection of different column list with another dataset


I have a dataframe with people and the food they like:

import polars as pl

df_class = pl.DataFrame(
    {
        'people': ['alan', 'bob', 'charlie'],
        'food': [['orange', 'apple'], ['banana', 'cherry'], ['banana', 'grape']]
    }
)
print(df_class)

shape: (3, 2)
┌─────────┬──────────────────────┐
│ people  ┆ food                 │
│ ---     ┆ ---                  │
│ str     ┆ list[str]            │
╞═════════╪══════════════════════╡
│ alan    ┆ ["orange", "apple"]  │
│ bob     ┆ ["banana", "cherry"] │
│ charlie ┆ ["banana", "grape"]  │
└─────────┴──────────────────────┘

I also have a data structure with animals and the things they like to eat:

animals = [
    ('squirrel', ('acorn', 'almond')),
    ('parrot', ('cracker', 'grape', 'guava')),
    ('dog', ('chicken', 'bone')),
    ('monkey', ('banana', 'plants'))
]

I want to add a new column pets in df_class, such that pets is a list of the animals that have at least one food in common with the corresponding person:

df_class.with_columns(pets=???)   # <-- not sure what to do here

shape: (3, 3)
┌─────────┬──────────────────────┬──────────────────────┐
│ people  ┆ food                 ┆ pets                 │
│ ---     ┆ ---                  ┆ ---                  │
│ str     ┆ list[str]            ┆ list[str]            │
╞═════════╪══════════════════════╪══════════════════════╡
│ alan    ┆ ["orange", "apple"]  ┆ []                   │
│ bob     ┆ ["banana", "cherry"] ┆ ["monkey"]           │
│ charlie ┆ ["banana", "grape"]  ┆ ["monkey", "parrot"] │
└─────────┴──────────────────────┴──────────────────────┘
  • I have some flexibility in the data structure for animals, in case this is easier to do with, e.g., some sort of set intersection
  • the order of pets is unimportant
  • pets should contain unique values
  • I'm looking for a single expression that would achieve the desired result, so as to fit within a larger framework of a list of expressions that perform transformations on other columns of my actual dataset

Side note: my title seems kind of clunky and I'm open to suggestions to reword so that it's easier to find by others that might be trying to solve a similar problem.


Solution

  • This amounts to a join on the exploded lists.

    It can be kept as a "single expression" by wrapping the join in .map_batches(), which passes the food column to the function as a Series:

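    df_class.with_columns(pets =
       pl.col("food").map_batches(lambda col:
          pl.LazyFrame(col)
            .with_row_index()       # remember which row each food came from
            .explode("food")        # one row per (index, food) pair
            .join(
               # orient="row": each tuple in animals is one (pets, food) row
               pl.LazyFrame(animals, schema=["pets", "food"], orient="row").explode("food"),
               on = "food",
               how = "left"         # keep foods that match no animal
            )
            .group_by("index", maintain_order=True)
            .agg(
               pl.col("pets").unique().drop_nulls()
            )
            .collect()
            .get_column("pets")
       )
    )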
    
    shape: (3, 3)
    ┌─────────┬──────────────────────┬──────────────────────┐
    │ people  ┆ food                 ┆ pets                 │
    │ ---     ┆ ---                  ┆ ---                  │
    │ str     ┆ list[str]            ┆ list[str]            │
    ╞═════════╪══════════════════════╪══════════════════════╡
    │ alan    ┆ ["orange", "apple"]  ┆ []                   │
    │ bob     ┆ ["banana", "cherry"] ┆ ["monkey"]           │
    │ charlie ┆ ["banana", "grape"]  ┆ ["monkey", "parrot"] │
    └─────────┴──────────────────────┴──────────────────────┘
    

    Explanation

    We add a row index to the "left" frame and explode the lists.

    (The row index will allow us to rebuild the rows later on.)

    df_class_long = df_class.with_row_index().explode("food")
    
    # shape: (6, 3)
    # ┌───────┬─────────┬────────┐
    # │ index ┆ people  ┆ food   │
    # │ ---   ┆ ---     ┆ ---    │
    # │ u32   ┆ str     ┆ str    │
    # ╞═══════╪═════════╪════════╡
    # │ 0     ┆ alan    ┆ orange │
    # │ 0     ┆ alan    ┆ apple  │
    # │ 1     ┆ bob     ┆ banana │
    # │ 1     ┆ bob     ┆ cherry │
    # │ 2     ┆ charlie ┆ banana │
    # │ 2     ┆ charlie ┆ grape  │
    # └───────┴─────────┴────────┘
    
    # orient="row": each tuple in animals is one (pets, food) row
    df_pets_long = pl.DataFrame(animals, schema=["pets", "food"], orient="row").explode("food")
    
    # shape: (9, 2)
    # ┌──────────┬─────────┐
    # │ pets     ┆ food    │
    # │ ---      ┆ ---     │
    # │ str      ┆ str     │
    # ╞══════════╪═════════╡
    # │ squirrel ┆ acorn   │
    # │ squirrel ┆ almond  │
    # │ parrot   ┆ cracker │
    # │ parrot   ┆ grape   │
    # │ parrot   ┆ guava   │
    # │ dog      ┆ chicken │
    # │ dog      ┆ bone    │
    # │ monkey   ┆ banana  │
    # │ monkey   ┆ plants  │
    # └──────────┴─────────┘
    

    We then use a left join to find the "intersections", keeping every row from the left side so that people with no matching animal are not dropped.

    df_class_long.join(df_pets_long, on="food", how="left")
    
    # shape: (6, 4)
    # ┌───────┬─────────┬────────┬────────┐
    # │ index ┆ people  ┆ food   ┆ pets   │
    # │ ---   ┆ ---     ┆ ---    ┆ ---    │
    # │ u32   ┆ str     ┆ str    ┆ str    │
    # ╞═══════╪═════════╪════════╪════════╡
    # │ 0     ┆ alan    ┆ orange ┆ null   │
    # │ 0     ┆ alan    ┆ apple  ┆ null   │
    # │ 1     ┆ bob     ┆ banana ┆ monkey │
    # │ 1     ┆ bob     ┆ cherry ┆ null   │
    # │ 2     ┆ charlie ┆ banana ┆ monkey │
    # │ 2     ┆ charlie ┆ grape  ┆ parrot │
    # └───────┴─────────┴────────┴────────┘
    

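    We can then rebuild the "rows" with .group_by(), using .unique() and .drop_nulls() to deduplicate the pets and discard the nulls left by unmatched foods: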

    (df_class_long.join(df_pets_long, on="food", how="left")
      .group_by("index", maintain_order=True) # we need to retain row order
      .agg(pl.col("pets").unique().drop_nulls())
    )
    
    # shape: (3, 2)
    # ┌───────┬──────────────────────┐
    # │ index ┆ pets                 │
    # │ ---   ┆ ---                  │
    # │ u32   ┆ list[str]            │
    # ╞═══════╪══════════════════════╡
    # │ 0     ┆ []                   │
    # │ 1     ┆ ["monkey"]           │
    # │ 2     ┆ ["monkey", "parrot"] │
    # └───────┴──────────────────────┘