How can I recreate the following queries using Polars syntax?


I am currently trying to learn the Polars syntax by working through SQL/Pandas queries that I have done in the past.

Here is the SQL Query:

WITH fav_item_cte AS
(
    SELECT
        s.customer_id,
        m.product_name,
        COUNT(m.product_id) AS order_count,
        DENSE_RANK() OVER(PARTITION BY s.customer_id ORDER BY COUNT(s.customer_id) DESC) AS rank
    FROM dannys_diner.menu AS m
    JOIN dannys_diner.sales AS s
        ON m.product_id = s.product_id
    GROUP BY s.customer_id, m.product_name
)

SELECT
    customer_id,
    product_name,
    order_count
FROM fav_item_cte
WHERE rank = 1;

Here is the Pandas DataFrame after the initial inner join, provided as a minimal reproducible example (MRE).

import datetime

import pandas as pd

df = pd.DataFrame(
    {
        "customer_id": ['A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'A', 'B', 'B', 'C',
                        'C', 'C'],
        "order_id": [
            datetime.date(2021, 1, 1), datetime.date(2021, 1, 4),
            datetime.date(2021, 1, 11), datetime.date(2021, 1, 1),
            datetime.date(2021, 1, 7), datetime.date(2021, 1, 1),
            datetime.date(2021, 1, 2), datetime.date(2021, 1, 10),
            datetime.date(2021, 1, 11), datetime.date(2021, 1, 11),
            datetime.date(2021, 1, 16), datetime.date(2021, 2, 1),
            datetime.date(2021, 1, 1), datetime.date(2021, 1, 1),
            datetime.date(2021, 1, 7),
        ],
        "join_date": [1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3],
        "product_name": ['sushi', 'sushi', 'sushi', 'curry', 'curry', 'curry', 'curry',
                         'ramen', 'ramen', 'ramen', 'ramen', 'ramen', 'ramen', 'ramen',
                         'ramen'],
        "price": [10, 10, 10, 15, 15, 15, 15, 12, 12, 12, 12, 12, 12, 12, 12],
    }
)

Here is the Pandas Code that I used:

df = (df
    .groupby(["customer_id", "product_name"])
    # The MRE has no product_id column, so count a non-null column instead.
    .agg(order_count=("order_id", "count"))
    .reset_index()
    .assign(
        rank=lambda df_: df_.groupby("customer_id")["order_count"].rank(
            method="dense", ascending=False
        )
    )
    .query("rank == 1")
    .sort_values(["customer_id", "product_name"])
)

And here is the output that I was seeking:

df = pd.DataFrame(
    {
        "customer_id": ['A', 'B', 'B', 'B', 'C'],
        "product_name": ['ramen', 'curry', 'ramen', 'sushi', 'ramen'],
        "order_count": [3, 2, 2, 2, 3],
        "rank": [1, 1, 1, 1, 1],
    }
)

Here is the Polars code that I have so far.

import datetime

import polars as pl

dfp = pl.DataFrame(
    {
        "customer_id": ['A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'A', 'B', 'B', 'C',
                        'C', 'C'],
        "order_id": [
            datetime.date(2021, 1, 1), datetime.date(2021, 1, 4),
            datetime.date(2021, 1, 11), datetime.date(2021, 1, 1),
            datetime.date(2021, 1, 7), datetime.date(2021, 1, 1),
            datetime.date(2021, 1, 2), datetime.date(2021, 1, 10),
            datetime.date(2021, 1, 11), datetime.date(2021, 1, 11),
            datetime.date(2021, 1, 16), datetime.date(2021, 2, 1),
            datetime.date(2021, 1, 1), datetime.date(2021, 1, 1),
            datetime.date(2021, 1, 7),
        ],
        "join_date": [1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3],
        "product_name": ['sushi', 'sushi', 'sushi', 'curry', 'curry', 'curry', 'curry',
                         'ramen', 'ramen', 'ramen', 'ramen', 'ramen', 'ramen', 'ramen',
                         'ramen'],
        "price": [10, 10, 10, 15, 15, 15, 15, 12, 12, 12, 12, 12, 12, 12, 12],
    }
)
dfp = (
    (dfp
        .group_by("customer_id", "product_name")
        .agg(order_count=pl.col("product_name").count())
    )
    .sort("customer_id", "order_count", descending=[False, True])
    .with_columns(
        pl.col("order_count").rank(method="dense", descending=True).alias("rank")
    )
    .filter(pl.col("rank") == 1)
)

My Polars query doesn't pick up customer B's three items, which are tied at two orders each, so they are missing from the filtered result. Here is the output without the .filter(pl.col("rank") == 1) step:

shape: (7, 4)
┌─────────────┬──────────────┬─────────────┬──────┐
│ customer_id ┆ product_name ┆ order_count ┆ rank │
│ ---         ┆ ---          ┆ ---         ┆ ---  │
│ str         ┆ str          ┆ u32         ┆ u32  │
╞═════════════╪══════════════╪═════════════╪══════╡
│ A           ┆ ramen        ┆ 3           ┆ 1    │
│ A           ┆ curry        ┆ 2           ┆ 2    │
│ A           ┆ sushi        ┆ 1           ┆ 3    │
│ B           ┆ curry        ┆ 2           ┆ 2    │
│ B           ┆ ramen        ┆ 2           ┆ 2    │
│ B           ┆ sushi        ┆ 2           ┆ 2    │
│ C           ┆ ramen        ┆ 3           ┆ 1    │
└─────────────┴──────────────┴─────────────┴──────┘

Edit: Based on @jqurious's comment below, the following might be the closest translation, but I am still curious whether it is the best way to do this in Polars.


dfp = (
    (
        dfp
        .group_by("customer_id", "product_name")
        .agg(order_count=pl.col("product_name").count())
    )
    .sort("customer_id", "order_count", descending=[False, True])
    .with_columns(
        pl.col("order_count")
        .rank(method="dense", descending=True)
        .over("customer_id")
        .alias("rank")
    )
    .filter(pl.col("rank") == 1)
)

shape: (5, 4)
┌─────────────┬──────────────┬─────────────┬──────┐
│ customer_id ┆ product_name ┆ order_count ┆ rank │
│ ---         ┆ ---          ┆ ---         ┆ ---  │
│ str         ┆ str          ┆ u32         ┆ u32  │
╞═════════════╪══════════════╪═════════════╪══════╡
│ A           ┆ ramen        ┆ 3           ┆ 1    │
│ B           ┆ sushi        ┆ 2           ┆ 1    │
│ B           ┆ ramen        ┆ 2           ┆ 1    │
│ B           ┆ curry        ┆ 2           ┆ 1    │
│ C           ┆ ramen        ┆ 3           ┆ 1    │
└─────────────┴──────────────┴─────────────┴──────┘

Solution

  • As per the comments, it seems you want to keep rows with the "min rank". Or put another way:

    The most common "product_name(s)" per "customer_id" (i.e. the .mode()) along with their count(s).

    What comes to mind is:

    • add the len (row count) of each customer_id, product_name group
    • keep only the rows whose len equals the max len for that customer_id
    • group_by to remove duplicates
    • aggregate with len()

    # df is the original (unaggregated) Polars DataFrame from the question
    (df.with_columns(pl.len().over("customer_id", "product_name"))
       .filter(pl.col.len == pl.col.len.max().over("customer_id"))
       .group_by("customer_id", "product_name")
       .len()
    )
    
    shape: (5, 3)
    ┌─────────────┬──────────────┬─────┐
    │ customer_id ┆ product_name ┆ len │
    │ ---         ┆ ---          ┆ --- │
    │ str         ┆ str          ┆ u32 │
    ╞═════════════╪══════════════╪═════╡
    │ B           ┆ curry        ┆ 2   │
    │ C           ┆ ramen        ┆ 3   │
    │ B           ┆ sushi        ┆ 2   │
    │ B           ┆ ramen        ┆ 2   │
    │ A           ┆ ramen        ┆ 3   │
    └─────────────┴──────────────┴─────┘
    

    There are probably quite a few different ways to approach it.