
Polars API registering and type checkers


I consistently get type errors from both mypy and pyright when using the Polars namespace registration functions. Is there a way to avoid these type-checker errors other than adding # type: ignore[attr-defined] every time I use a function from my custom namespace?

The example below follows the official documentation (https://docs.pola.rs/py-polars/html/reference/api.html):

file checker.py:

import polars as pl

@pl.api.register_expr_namespace("greetings")
class Greetings:
    def __init__(self, expr: pl.Expr):
        self._expr = expr

    def hello(self) -> pl.Expr:
        return (pl.lit("Hello ") + self._expr).alias("hi there")

    def goodbye(self) -> pl.Expr:
        return (pl.lit("Sayōnara ") + self._expr).alias("bye")


print(pl.DataFrame(data=["world", "world!", "world!!"]).select(
    [
        pl.all().greetings.hello(), # type: ignore[attr-defined]
        pl.all().greetings.goodbye(),
    ]
))

% mypy checker.py
checker.py:19: error: "Expr" has no attribute "greetings"  [attr-defined]
Found 1 error in 1 file (checked 1 source file)
% mypy --version
mypy 1.8.0 (compiled: yes)
% pyright checker.py
/path/to/checker.py
  /path/to/checker.py:19:18 - error: Cannot access member "greetings" for type "Expr"
    Member "greetings" is unknown (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
% pyright --version
pyright 1.1.343

Solution

  • Polars' namespace registration cannot be expressed in Python's static type system. The most direct fix is for Polars itself to provide a typing hint for dynamic attribute access. Without that, only type checkers with plugin support give you a way around the problem.

    Polars classes with dynamically-registerable namespaces need to define __getattr__

    Type checkers (both mypy and Pyright) require error suppression because polars.expr.expr.Expr does not have the dynamic attribute accessor defined. The presence of something like this:

    class Expr:
        if typing.TYPE_CHECKING:
            def __getattr__(self, attr_name: str, /) -> typing.Any: ...
    

    is sufficient to silence type checkers about dynamic attribute access (see mypy Playground, Pyright Playground, also typing documentation >> Type System Reference >> Type Stubs >> Attribute Access - this equally applies to inline types in source files, such as Polars' .py files). I would suggest filing an issue with the Polars devs to ask them to add a dynamic __getattr__ for typing purposes only.
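
    To make the effect concrete, here is a minimal, Polars-free sketch. FakeExpr and register_namespace are hypothetical stand-ins invented for illustration (they are not the Polars API); the only point is that a __getattr__ declared under typing.TYPE_CHECKING is visible to type checkers but never exists at runtime.

```python
from __future__ import annotations

import typing


class FakeExpr:
    """Hypothetical stand-in for polars.Expr (NOT the real class)."""

    def __init__(self, value: str) -> None:
        self.value = value

    if typing.TYPE_CHECKING:
        # Seen only by type checkers: any unknown attribute is typed as Any.
        def __getattr__(self, attr_name: str, /) -> typing.Any: ...


def register_namespace(name: str) -> typing.Callable[[type], type]:
    """Hypothetical, simplified imitation of a namespace-registering decorator."""

    def decorator(ns_cls: type) -> type:
        # Attach the namespace as a property; static checkers cannot follow this.
        setattr(FakeExpr, name, property(lambda self: ns_cls(self)))
        return ns_cls

    return decorator


@register_namespace("greetings")
class Greetings:
    def __init__(self, expr: FakeExpr) -> None:
        self._expr = expr

    def hello(self) -> str:
        return "Hello " + self._expr.value


# Works at runtime, and type checkers accept it because of the typing-only hint.
result = FakeExpr("world").greetings.hello()

# The hint leaves no trace at runtime: TYPE_CHECKING is False when executing.
has_runtime_getattr = "__getattr__" in vars(FakeExpr)
```

    At runtime greetings is found as an ordinary class attribute, so the absent __getattr__ is never needed. The trade-off is that a type checker then treats every unknown attribute as Any, which removes the error but also removes any checking of the namespace's methods.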

    Pyright requires per-file or per-line controls

    For Pyright, these errors can only be suppressed with per-line # type: ignore comments or with a set of file-level controls such as # pyright: reportUnknownMemberType=none, reportGeneralTypeIssues=none. Pyright's BDFL has previously expressed reluctance to add plugin support, so Pyright cannot be extended for this until someone writes a typing PEP proposing a standard way to register namespaces dynamically.
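
    As a rough illustration of the file-level approach, the sketch below places the control comment at the top of a module. FakeExpr is a hypothetical stand-in so that the snippet runs without Polars, and the rule name is the one quoted in the error output above; depending on the Pyright version, the diagnostic may be reported under a different rule (for example reportAttributeAccessIssue), so treat the rule name as an assumption to adjust.

```python
# pyright: reportGeneralTypeIssues=none
# The comment above applies to this whole file. The rule name is an assumption
# taken from the error message; adjust it to whatever rule Pyright reports.


class FakeExpr:
    """Hypothetical stand-in for polars.Expr (NOT the real class)."""

    def __init__(self, value: str) -> None:
        self.value = value


class Greetings:
    def __init__(self, expr: FakeExpr) -> None:
        self._expr = expr

    def hello(self) -> str:
        return "Hello " + self._expr.value


# Runtime registration that a static checker cannot see.
setattr(FakeExpr, "greetings", property(lambda self: Greetings(self)))

# Without the file-level control, Pyright would flag `.greetings` here.
message = FakeExpr("world").greetings.hello()
```

    The file-level comment silences the rule for every line in the module, including genuine mistakes, so it is a blunter tool than a per-line ignore.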

    mypy can offer a statically-typed solution in the form of a mypy plugin

    It would be great if Polars supported dynamic namespace registration through a mypy plugin, at least for mypy users; I suggest pushing for this. You can use the following as inspiration: it provides full static type checking of the registered namespaces.

    mypy static type checking result

    This is the kind of result you'd want to get:

    import polars as pl
    
    @pl.api.register_expr_namespace("greetings")
    class Greetings:
        def __init__(self, expr: pl.Expr):
            self._expr = expr
    
        def hello(self) -> pl.Expr:
            return (pl.lit("Hello ") + self._expr).alias("hi there")
    
        def goodbye(self) -> pl.Expr:
            return (pl.lit("Sayōnara ") + self._expr).alias("bye")
    
    >>> print(
    ...     pl.DataFrame(data=["world", "world!", "world!!"]).select(
    ...         [
    ...             pl.all().greetings.hello(),
    ...             pl.all().greetings.goodbye(1),  # mypy: Too many arguments for "goodbye" of "Greetings"  [call-arg]
    ...             pl.all().asdfjkl                # mypy: `polars.expr.expr.Expr` object has no attribute `asdfjkl`  [misc]
    ...         ]
    ...     )
    ... )
    ...
    

    Project structure

    project/
      mypy.ini
      mypy_polars_plugin.py
      test.py
    

    Implementation

    Contents of mypy.ini

    [mypy]
    plugins = mypy_polars_plugin.py
    

    Contents of mypy_polars_plugin.py

    from __future__ import annotations
    
    import typing_extensions as t
    
    import mypy.nodes
    import mypy.plugin
    import mypy.plugins.common
    
    if t.TYPE_CHECKING:
        import collections.abc as cx
    
        import mypy.options
        import mypy.types
    
    STR___GETATTR___NAME: t.Final = "__getattr__"
    STR_POLARS_EXPR_MODULE_NAME: t.Final = "polars.expr.expr"
    STR_POLARS_EXPR_FULLNAME: t.Final = f"{STR_POLARS_EXPR_MODULE_NAME}.Expr"
    STR_POLARS_EXPR_REGISTER_EXPR_NAMESPACE_FULLNAME: t.Final = "polars.api.register_expr_namespace"
    
    def plugin(version: str) -> type[PolarsPlugin]:
        return PolarsPlugin
    
    class PolarsPlugin(mypy.plugin.Plugin):
    
        _polars_expr_namespace_name_to_type_dict: dict[str, mypy.types.Type]
    
        def __init__(self, options: mypy.options.Options) -> None:
            super().__init__(options)
            self._polars_expr_namespace_name_to_type_dict = {}
    
        @t.override
        def get_customize_class_mro_hook(
            self, fullname: str
        ) -> cx.Callable[[mypy.plugin.ClassDefContext], None] | None:
            """
            mypy requires the presence of `__getattr__` or `__getattribute__` for
            `get_attribute_hook` to work on dynamic attributes. This hook-getter adds
            `__getattr__` to the class definition of `polars.expr.expr.Expr`.
            """
    
            if fullname == STR_POLARS_EXPR_FULLNAME:
                return add_getattr
            return None
    
        @t.override
        def get_class_decorator_hook_2(
            self, fullname: str
        ) -> cx.Callable[[mypy.plugin.ClassDefContext], bool] | None:
            """
            Makes mypy recognise the class decorator factory
            `@polars.api.register_expr_namespace(...)` in the following context:
    
                @polars.api.register_expr_namespace(<namespace name>)
                class <Namespace>: ...
    
            Accumulates a mapping of a bunch of potential attributes to be accessible on
            instances of `polars.expr.expr.Expr`; the mapping has entries which look like
                `<namespace name>: <Namespace>`
            """
    
            if fullname == STR_POLARS_EXPR_REGISTER_EXPR_NAMESPACE_FULLNAME:
                return self.polars_expr_namespace_registering_hook
            return None
    
        @t.override
        def get_attribute_hook(
            self, fullname: str
        ) -> cx.Callable[[mypy.plugin.AttributeContext], mypy.types.Type] | None:
            """
            Makes mypy understand that, whenever an attribute is accessed from instances of
            `polars.expr.expr.Expr` and the attribute doesn't exist, reach for the
            attributes accumulated in the mapping through the actions of
            `get_class_decorator_hook_2`.
            """
    
            if fullname.startswith(f"{STR_POLARS_EXPR_FULLNAME}."):
                return self.polars_expr_attribute_hook
            return None
    
        def polars_expr_namespace_registering_hook(
            self, ctx: mypy.plugin.ClassDefContext
        ) -> bool:
            """
            Use the decorator factory `polars.api.register_expr_namespace(<namespace name>)`
            to register available dynamic attributes later accessed from instances of
            `polars.expr.expr.Expr`. Returns whether the class has enough information to be
            considered semantically analysed.
            """
    
            # Ensure that the class decorator expression looks like
            # `@polars.api.register_expr_namespace(<namespace name>)`
            namespace_arg: str | None
            if (
                (not isinstance(ctx.reason, mypy.nodes.CallExpr))
                or (len(ctx.reason.args) != 1)
                or (
                    (namespace_arg := ctx.api.parse_str_literal(ctx.reason.args[0])) is None
                )
            ):
                # If the decorator factory expression doesn't look valid, do an early
                # return.
                return True
    
            self._polars_expr_namespace_name_to_type_dict[
                namespace_arg
            ] = ctx.api.named_type(ctx.cls.fullname)
    
            return True
    
        def polars_expr_attribute_hook(
            self, ctx: mypy.plugin.AttributeContext
        ) -> mypy.types.Type:
            """
            Reaches for registered namespaces when accessing attributes on instances of
            `polars.expr.expr.Expr`. Shows an error when the attribute doesn't exist.
            """
    
            assert isinstance(ctx.context, mypy.nodes.MemberExpr)
            attr_name: str = ctx.context.name
            namespace_type: mypy.types.Type | None = (
                self._polars_expr_namespace_name_to_type_dict.get(attr_name)
            )
            if namespace_type is not None:
                return namespace_type
            else:
                ctx.api.fail(
                    f"`{STR_POLARS_EXPR_FULLNAME}` object has no attribute `{attr_name}`",
                    ctx.context,
                )
                return mypy.types.AnyType(mypy.types.TypeOfAny.from_error)
    
    
    def add_getattr(ctx: mypy.plugin.ClassDefContext) -> None:
    
        mypy.plugins.common.add_method_to_class(
            ctx.api,
            cls=ctx.cls,
            name=STR___GETATTR___NAME,
            args=[
                mypy.nodes.Argument(
                    variable=mypy.nodes.Var(
                        name="attr_name", type=ctx.api.named_type("builtins.str")
                    ),
                    type_annotation=ctx.api.named_type("builtins.str"),
                    initializer=None,
                    kind=mypy.nodes.ArgKind.ARG_POS,
                    pos_only=True,
                )
            ],
            return_type=mypy.types.AnyType(mypy.types.TypeOfAny.implementation_artifact),
            self_type=ctx.api.named_type(STR_POLARS_EXPR_FULLNAME),
        )
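
    The plugin's bookkeeping can be pictured without mypy at all. The sketch below is a runtime analogy only: NamespaceRegistry is a hypothetical class, not part of mypy or Polars. Its register method plays the role of the class decorator hook (record namespace name to class) and resolve plays the role of the attribute hook (look the name up, report an error otherwise).

```python
from __future__ import annotations


class NamespaceRegistry:
    """Hypothetical toy model of the plugin's name -> namespace-class mapping."""

    def __init__(self, owner: str) -> None:
        self._owner = owner
        self._namespaces: dict[str, type] = {}

    def register(self, name: str):
        """Analogue of the decorator hook: remember which class backs `name`."""

        def decorator(cls: type) -> type:
            self._namespaces[name] = cls
            return cls

        return decorator

    def resolve(self, attr_name: str) -> type:
        """Analogue of the attribute hook: return the class or report an error."""
        try:
            return self._namespaces[attr_name]
        except KeyError:
            raise AttributeError(
                f"`{self._owner}` object has no attribute `{attr_name}`"
            ) from None


registry = NamespaceRegistry("polars.expr.expr.Expr")


@registry.register("greetings")
class Greetings:
    def hello(self) -> str:
        return "Hello"


# A registered name resolves to its namespace class ...
resolved = registry.resolve("greetings")

# ... and an unregistered one produces the plugin-style error message.
try:
    registry.resolve("asdfjkl")
except AttributeError as exc:
    error_message = str(exc)
```

    In the real plugin the mapping stores mypy types rather than classes, and the failure is reported through ctx.api.fail instead of an exception, but the order of operations is the same: decorators are processed first, attribute accesses consult the mapping afterwards.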
    

    Contents of test.py

    import polars as pl
    
    
    @pl.api.register_expr_namespace("greetings")
    class Greetings:
        def __init__(self, expr: pl.Expr):
            self._expr = expr
    
        def hello(self) -> pl.Expr:
            return (pl.lit("Hello ") + self._expr).alias("hi there")
    
        def goodbye(self) -> pl.Expr:
            return (pl.lit("Sayōnara ") + self._expr).alias("bye")
    
    
    print(
        pl.DataFrame(data=["world", "world!", "world!!"]).select(
            [
                pl.all().greetings.hello(),
                pl.all().greetings.goodbye(1),  # mypy: Too many arguments for "goodbye" of "Greetings"  [call-arg]
                pl.all().asdfjkl                # mypy: `polars.expr.expr.Expr` object has no attribute `asdfjkl`  [misc]
            ]
        )
    )