Tags: pandas, dataframe, protocols, typing

Is there a way to type hint a pandas object's index?


I'd like to type hint that a pandas DataFrame must have a DatetimeIndex. I was hoping there might be some way to do this with protocols, but it looks like there isn't. Something in the spirit of this:

from typing import Protocol

import pandas as pd

class TSFrame(Protocol):
    index: pd.DatetimeIndex

def test(df: TSFrame):
    # Do stuff with df.index.methods_supported_by_dtidx_only
    pass

nontsdf = pd.DataFrame()
tsdf = pd.DataFrame(index=pd.DatetimeIndex(pd.date_range("2022-01-01", "2022-01-02")))
test(nontsdf)  # goal is for my type checker to complain here
test(tsdf)  # and not complain here

My type checker instead complains in both cases. Confusingly, if I create an analogous test on a plain class where the protocol attribute is typed int, it complains in neither case.

from typing import Any, Protocol

class IntWanted(Protocol):
    var: int

class TestClass:
    def __init__(self, var: Any) -> None:
        self.var = var

def foo(a: IntWanted) -> int:
    return a.var

good = TestClass(1)
bad = TestClass("x")
foo(good)
foo(bad) 
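A likely explanation for both results, assuming the checker is mypy: protocol conformance is decided per class, never per instance. To a checker, nontsdf and tsdf are both just pd.DataFrame, and DataFrame.index is presumably typed as the general Index rather than DatetimeIndex, so neither call matches the protocol. In the int example, self.var is assigned from a parameter annotated Any, so the checker infers the attribute as Any, which is compatible with int, and both calls pass. Nothing is enforced at runtime either; as this small sketch shows, even a runtime_checkable protocol only tests that the attribute exists:

```python
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class IntWanted(Protocol):
    var: int

class TestClass:
    def __init__(self, var: Any) -> None:
        # The parameter is annotated Any, so a static checker infers the
        # attribute self.var as Any, and Any satisfies the protocol's int.
        self.var = var

def foo(a: IntWanted) -> int:
    return a.var

good = TestClass(1)
bad = TestClass("x")

# isinstance() against a runtime_checkable protocol only checks that the
# attribute is present, not its type, so both objects pass.
print(isinstance(good, IntWanted), isinstance(bad, IntWanted))  # True True
# The "-> int" annotation is not enforced at runtime either.
print(foo(bad))  # x
```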

Other ways I can think of to handle these time-series DataFrames:

  1. Subclass DataFrame and add validation that the index is a DatetimeIndex. Convert every frame I have into an instance of this class and use that class in type hints everywhere. Would that let mypy know that its index has the attributes of a DatetimeIndex, though? I think not. For example:
import pandas as pd

class TSFrame(pd.DataFrame):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.index, pd.DatetimeIndex)
  2. Make an entirely new object that holds the DataFrame as an attribute and defines an index property returning a DatetimeIndex built from that frame's index, so mypy knows what type the index has. This feels heavy.
import attr
import pandas as pd

@attr.s(auto_attribs=True)
class TSFrame:
    df: pd.DataFrame

    def __attrs_post_init__(self) -> None:
        # Check the wrapped frame's own index. Checking self.index would not
        # catch e.g. an integer index, because the property converts it first.
        assert isinstance(self.df.index, pd.DatetimeIndex)

    @property
    def index(self) -> pd.DatetimeIndex:
        return pd.DatetimeIndex(self.df.index)
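On option 1, the guess is probably right: a subclass inherits DataFrame.index with its general Index type, so the runtime assert tells a static checker nothing. A minimal sketch, assuming it is acceptable to expose the narrowed index under a separate property (dtindex is a hypothetical name) instead of overriding index itself, which pandas relies on internally:

```python
import pandas as pd

class TSFrame(pd.DataFrame):
    """DataFrame subclass that requires a DatetimeIndex (sketch)."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Runtime validation, as in option 1.
        if not isinstance(self.index, pd.DatetimeIndex):
            raise TypeError(
                f"TSFrame needs a DatetimeIndex, got {type(self.index).__name__}"
            )

    @property
    def dtindex(self) -> pd.DatetimeIndex:
        # Hypothetical helper: binding the index to a local and narrowing it
        # with isinstance lets a static checker treat it as a DatetimeIndex.
        idx = self.index
        assert isinstance(idx, pd.DatetimeIndex)
        return idx

tsf = TSFrame({"x": [1, 2]}, index=pd.date_range("2022-01-01", periods=2))
print(tsf.dtindex.day.tolist())  # [1, 2]
```

One caveat: unless the subclass also overrides pandas' _constructor hook, most operations (for example tsf.head()) return a plain pd.DataFrame, so the TSFrame type does not survive method chains.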

Ideas appreciated.


Solution

  • You can use pandera (together with pandas-stubs) to get both a static check with mypy and a validation check at runtime.

    1. pip install pandera[mypy]
    2. Create a mypy.ini file with:
    [mypy]
    plugins = pandera.mypy
    

    demo.py

    import pandera as pa
    import pandas as pd
    from pandera.typing import DataFrame, Index
    
    class TSFrame(pa.DataFrameModel):
        idx: Index[pa.Timestamp] = pa.Field(check_name=False)
    
    @pa.check_types  # validates the arguments at runtime
    def test(df: DataFrame[TSFrame]):  # checked statically by mypy
        pass
    
    nontsdf = pd.DataFrame()
    tsdf = DataFrame[TSFrame](index=pd.DatetimeIndex(pd.date_range("2022-01-01", "2022-01-02")))
    test(nontsdf)  # mypy error; SchemaError at runtime
    test(tsdf)
    

    Usage:

    [...]$ mypy demo.py
    demo1.py:14: error: Argument 1 to "test" has incompatible type "pandas.core.frame.DataFrame"; expected "pandera.typing.pandas.DataFrame[TSFrame]"  [arg-type]
    Found 1 error in 1 file (checked 1 source file)
    
    [...]$ python demo.py
    ...
    pandera.errors.SchemaError: error in check_types decorator of function 'test': expected series 'None' to have type datetime64[ns], got int64
    

    More information: