I have a dataframe with people and the food they like:
import polars as pl

df_class = pl.DataFrame(
    {
        'people': ['alan', 'bob', 'charlie'],
        'food': [['orange', 'apple'], ['banana', 'cherry'], ['banana', 'grape']]
    }
)
print(df_class)
shape: (3, 2)
┌─────────┬──────────────────────┐
│ people ┆ food │
│ --- ┆ --- │
│ str ┆ list[str] │
╞═════════╪══════════════════════╡
│ alan ┆ ["orange", "apple"] │
│ bob ┆ ["banana", "cherry"] │
│ charlie ┆ ["banana", "grape"] │
└─────────┴──────────────────────┘
And I have a data structure with animals and the things they like to eat:
animals = [
    ('squirrel', ('acorn', 'almond')),
    ('parrot', ('cracker', 'grape', 'guava')),
    ('dog', ('chicken', 'bone')),
    ('monkey', ('banana', 'plants'))
]
I want to add a new column pets in df_class, such that pets is a list of the animals that have at least one food in common with the corresponding person:
df_class.with_columns(pets=???) # <-- not sure what to do here
shape: (3, 3)
┌─────────┬──────────────────────┬──────────────────────┐
│ people ┆ food ┆ pets │
│ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] │
╞═════════╪══════════════════════╪══════════════════════╡
│ alan ┆ ["orange", "apple"] ┆ [] │
│ bob ┆ ["banana", "cherry"] ┆ ["monkey"] │
│ charlie ┆ ["banana", "grape"] ┆ ["monkey", "parrot"] │
└─────────┴──────────────────────┴──────────────────────┘
I'm flexible about the structure of animals, in case, e.g., this is easier to do with some sort of set intersection.

Side note: my title seems kind of clunky, and I'm open to suggestions for rewording it so that it's easier to find for others who might be trying to solve a similar problem.
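As a reference for the expected pets column, the set-intersection idea can be written out in plain Python. This is only a sketch for checking results: the helper name pets_for is hypothetical, and returning each list sorted by name is an assumption (it happens to match the expected output above).

```python
# Plain-Python reference for the desired "pets" column, using set intersection.
# Data copied from the question; no Polars involved.
people = ["alan", "bob", "charlie"]
food = [["orange", "apple"], ["banana", "cherry"], ["banana", "grape"]]
animals = [
    ("squirrel", ("acorn", "almond")),
    ("parrot", ("cracker", "grape", "guava")),
    ("dog", ("chicken", "bone")),
    ("monkey", ("banana", "plants")),
]


def pets_for(liked):
    """Hypothetical helper: animals sharing at least one food with `liked`, sorted by name."""
    liked = set(liked)
    return sorted(name for name, eats in animals if liked & set(eats))


pets = [pets_for(f) for f in food]
print(dict(zip(people, pets)))
# {'alan': [], 'bob': ['monkey'], 'charlie': ['monkey', 'parrot']}
```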
It looks like a join on the exploded lists. It can be kept as a "single expression" by putting it inside .map_batches():
df_class.with_columns(pets =
    pl.col("food").map_batches(lambda col:
        pl.LazyFrame(col)
        .with_row_index()
        .explode("food")
        .join(
            # orient="row": each tuple in animals is one (pet, foods) row
            pl.LazyFrame(animals, schema=["pets", "food"], orient="row").explode("food"),
            on = "food",
            how = "left"
        )
        .group_by("index", maintain_order=True)
        .agg(
            pl.col("pets").unique().drop_nulls()
        )
        .collect()
        .get_column("pets")
    )
)
shape: (3, 3)
┌─────────┬──────────────────────┬──────────────────────┐
│ people ┆ food ┆ pets │
│ --- ┆ --- ┆ --- │
│ str ┆ list[str] ┆ list[str] │
╞═════════╪══════════════════════╪══════════════════════╡
│ alan ┆ ["orange", "apple"] ┆ [] │
│ bob ┆ ["banana", "cherry"] ┆ ["monkey"] │
│ charlie ┆ ["banana", "grape"] ┆ ["monkey", "parrot"] │
└─────────┴──────────────────────┴──────────────────────┘
We add a row index to the "left" frame and explode the lists.
(The row index will allow us to rebuild the rows later on.)
df_class_long = df_class.with_row_index().explode("food")
# shape: (6, 3)
# ┌───────┬─────────┬────────┐
# │ index ┆ people ┆ food │
# │ --- ┆ --- ┆ --- │
# │ u32 ┆ str ┆ str │
# ╞═══════╪═════════╪════════╡
# │ 0 ┆ alan ┆ orange │
# │ 0 ┆ alan ┆ apple │
# │ 1 ┆ bob ┆ banana │
# │ 1 ┆ bob ┆ cherry │
# │ 2 ┆ charlie ┆ banana │
# │ 2 ┆ charlie ┆ grape │
# └───────┴─────────┴────────┘
df_pets_long = pl.DataFrame(animals, schema=["pets", "food"], orient="row").explode("food")
# shape: (9, 2)
# ┌──────────┬─────────┐
# │ pets ┆ food │
# │ --- ┆ --- │
# │ str ┆ str │
# ╞══════════╪═════════╡
# │ squirrel ┆ acorn │
# │ squirrel ┆ almond │
# │ parrot ┆ cracker │
# │ parrot ┆ grape │
# │ parrot ┆ guava │
# │ dog ┆ chicken │
# │ dog ┆ bone │
# │ monkey ┆ banana │
# │ monkey ┆ plants │
# └──────────┴─────────┘
We then use a left join to find the "intersections" (whilst keeping all the rows from the left side).
df_class_long.join(df_pets_long, on="food", how="left")
# shape: (6, 4)
# ┌───────┬─────────┬────────┬────────┐
# │ index ┆ people ┆ food ┆ pets │
# │ --- ┆ --- ┆ --- ┆ --- │
# │ u32 ┆ str ┆ str ┆ str │
# ╞═══════╪═════════╪════════╪════════╡
# │ 0 ┆ alan ┆ orange ┆ null │
# │ 0 ┆ alan ┆ apple ┆ null │
# │ 1 ┆ bob ┆ banana ┆ monkey │
# │ 1 ┆ bob ┆ cherry ┆ null │
# │ 2 ┆ charlie ┆ banana ┆ monkey │
# │ 2 ┆ charlie ┆ grape ┆ parrot │
# └───────┴─────────┴────────┴────────┘
We can then rebuild the "rows" with .group_by() and keep only the .unique(), non-null pets.
(df_class_long.join(df_pets_long, on="food", how="left")
    .group_by("index", maintain_order=True)  # we need to retain row order
    .agg(pl.col("pets").unique().drop_nulls())
)
# shape: (3, 2)
# ┌───────┬──────────────────────┐
# │ index ┆ pets │
# │ --- ┆ --- │
# │ u32 ┆ list[str] │
# ╞═══════╪══════════════════════╡
# │ 0 ┆ [] │
# │ 1 ┆ ["monkey"] │
# │ 2 ┆ ["monkey", "parrot"] │
# └───────┴──────────────────────┘