Python Polars: Apply function to two columns and an argument


Intro

In Polars I would like to run fairly complex queries, and I would like to simplify the process by splitting the operations into functions. Before I can do that, I need to find out how to pass multiple columns and extra arguments to these functions.

Example Data

# Libraries
import polars as pl
from datetime import datetime

# Data
test_data = pl.DataFrame({
    "class": ['A', 'A', 'A', 'B', 'B', 'C'],
    "date": [datetime(2020, 1, 31), datetime(2020, 2, 28), datetime(2021, 1, 31),
              datetime(2022, 1, 31), datetime(2023, 2, 28),
              datetime(2020, 1, 31)],
    "status": [1,0,1,0,1,0]
})

The Problem

For each group, I would like to know whether a reference date falls in the same year-month as a value in the date column of the dataframe.

I would like to do something like this.

# Some date
reference_date = datetime(2020, 1, 2)

# What I would expect the query to look like
(test_data
 .group_by("class")
 .agg(
    pl.col("status").count().alias("row_count"), #just to show code that works
    pl.lit(reference_date).alias("reference_date"),
    pl.col("date", "status")
      .map_elements(lambda group: myfunc(group, reference_date))
      .alias("point_in_time_status")
 )
)

# The desired output
pl.DataFrame({
  "class": ['A', 'B', 'C'],
  "reference_date": [datetime(2020, 1, 2), datetime(2020, 1, 2), datetime(2020, 1, 2)],
  "point_in_time_status": [1,0,0]
})

But I simply cannot find any solution for doing operations on groups. Some suggest using pl.struct, but that just outputs some weird object without columns or anything to work with.

Example in R of the same operation

# Loading library
library(tidyverse)

# Creating dataframe
df <- data.frame(
 date = c(as.Date("2020-01-31"), 
          as.Date("2020-02-28"), as.Date("2021-01-31"), 
          as.Date("2022-01-31"), as.Date("2023-02-28"), 
          as.Date("2020-01-31")), 
 status = c(1,0,1,0,1,0), 
 class = c("A","A","A","B","B","C"))

# Finding status in overlapping months
ref_date = as.Date("2020-01-02")

df %>%
  group_by(class) %>%
  filter(format(date, "%Y-%m") == format(ref_date, "%Y-%m")) %>%
  filter(status == 1)

Solution

  • This should work with expressions

    reference_date = datetime(2020, 1, 2)
    (
        test_data
        .group_by('class', maintain_order=True)
        .agg(
           point_in_time_status = (
             (pl.col('date').dt.month_start() == pl.lit(reference_date).dt.month_start()) & 
             (pl.col('status')==1)
           ).any(),
           reference_date = pl.lit(reference_date)
        )
    )
    

    I'm using the month_start method instead of converting to a string format because string operations are comparatively slow and can be avoided here. Inside agg, the expression checks two things for each row: whether the first day of date's month equals the first day of reference_date's month, and whether status == 1. Both conditions are wrapped in parentheses, and the any aggregation is then applied to the combined result, so each class gets True if at least one of its rows satisfies both. Lastly, the reference_date column is added so that it appears in the output. You can swap the order of the two columns if you like.

    You can make that into a method, but you should do it with Polars expressions; otherwise you lose the efficiency gains that Polars brings to the table. You can then monkey patch the method onto pl.Expr or create your own namespace.

    As an example you could do:

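    def myFunc(self, reference_date, status):
        # self is the date expression; reference_date and status are expressions too
        return (
            (self.dt.month_start()==reference_date.dt.month_start()) &
            (status==1)
        ).any()

    # Monkey patch: make myFunc available on every Polars expression
    pl.Expr.myFunc=myFunc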
    (
        test_data
        .group_by('class', maintain_order=True)
        .agg(
            point_in_time_status = pl.col('date').myFunc(
             pl.lit(reference_date), pl.col('status')
            ),
            reference_date=pl.lit(reference_date)
        )
    )