python, python-typing, mypy, python-decorators

Python type hints for class decorators with self references


My ultimate goal is to write a system to easily record function calls (in particular for class methods).

I started by writing a class, FunctionCount, with a wrapper method that lets me decorate subclass methods and record their calls:

from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar

Param = ParamSpec("Param")
RetType = TypeVar("RetType")


CountType = TypeVar("CountType", bound="FunctionCount")


class FunctionCount(Generic[CountType]):
    def __init__(self, count_dict: dict[str, int]) -> None:
        self.count_dict = count_dict

    @staticmethod
    def count(
        func: Callable[Concatenate[CountType, Param], RetType],
    ) -> Callable[Concatenate[CountType, Param], RetType]:
        def wrapper(
            self: CountType, *args: Param.args, **kwargs: Param.kwargs
        ) -> RetType:
            function_name = f"{self.__class__.__name__}.{func.__name__}"
            if function_name not in self.count_dict:
                self.count_dict[function_name] = 0
            self.count_dict[function_name] += 1

            return func(self, *args, **kwargs)

        return wrapper

Now I can write subclasses and record their calls:

class A(FunctionCount):
    def __init__(self, count_dict: dict[str, int]) -> None:
        super().__init__(count_dict)

    @FunctionCount.count
    def func(self) -> None:
        pass

    @FunctionCount.count
    def func2(self) -> None:
        pass


count_dict: dict[str, int] = {}

a = A(count_dict)

a.func()
a.func()
a.func2()

print(count_dict)
assert count_dict == {"A.func": 2, "A.func2": 1}

This works really well and I was glad. But then I thought it would be nice to give the methods custom names, so I turned the wrapper into a decorator factory that takes an optional name:

class FunctionCount(Generic[CountType]):
    def __init__(self, count_dict: dict[str, int]) -> None:
        self.count_dict = count_dict

    @staticmethod
    def count(
        f_name: str | None = None,
    ) -> Callable[
        [Callable[Concatenate[CountType, Param], RetType]],
        Callable[Concatenate[CountType, Param], RetType],
    ]:
        def decorator(
            func: Callable[Concatenate[CountType, Param], RetType],
        ) -> Callable[Concatenate[CountType, Param], RetType]:
            def wrapper(
                self: CountType, *args: Param.args, **kwargs: Param.kwargs
            ) -> RetType:
                function_name = f_name or f"{self.__class__.__name__}.{func.__name__}"
                if function_name not in self.count_dict:
                    self.count_dict[function_name] = 0
                self.count_dict[function_name] += 1

                return func(self, *args, **kwargs)

            return wrapper

        return decorator
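Note that a decorator with arguments, such as @FunctionCount.count("custom_name"), is evaluated in two steps: the factory is called first, with only the name, and the method is passed to the returned decorator afterwards. A minimal standalone sketch of that ordering follows; the label factory and the calls list are hypothetical names, not part of the question's code.

```python
from typing import Callable, TypeVar

R = TypeVar("R")

# Hypothetical log of which labels were called, in order.
calls: list[str] = []


def label(name: str) -> Callable[[Callable[[], R]], Callable[[], R]]:
    # Hypothetical decorator factory: label(name) returns the real
    # decorator and never sees the decorated function itself.
    def decorator(func: Callable[[], R]) -> Callable[[], R]:
        def wrapper() -> R:
            calls.append(name)  # record the label, then delegate
            return func()

        return wrapper

    return decorator


@label("first")
def f() -> int:
    return 1


# The decorator syntax above does the same as this explicit two-step call.
def g_impl() -> int:
    return 2


decorator = label("second")  # step 1: only the name is known here
g = decorator(g_impl)  # step 2: the function is supplied afterwards
```

This ordering matters for typing: nothing about the decorated function is available when the factory itself is called.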

Then I just had to change the decorator calls:

class A(FunctionCount):
    def __init__(self, count_dict: dict[str, int]) -> None:
        super().__init__(count_dict)

    @FunctionCount.count()
    def func(self) -> None:
        pass

    @FunctionCount.count("custom_name")
    def func2(self) -> None:
        pass

count_dict: dict[str, int] = {}
a = A(count_dict)

a.func()
a.func()
a.func2()

print(count_dict)
assert count_dict == {"A.func": 2, "custom_name": 1}

This script also works well, but now mypy is giving me a hard time. When I call the a.func method, I get the following mypy error:

Invalid self argument "A" to attribute function "func" with type "Callable[[Never], None]" mypy(misc)

I guess switching from a plain wrapper to a decorator factory caused this error, but I can't understand why, or what I should do to fix it.

Does anyone know how to write a correctly typed decorator like this?


Solution

  • I think this should indeed be an answer instead of a comment.

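    Your class shouldn't be generic in CountType. The count method is generic by design, but the class itself is not. You're trying to express "a staticmethod that works on any callable whose first argument is a subtype of FunctionCount", not "a class that can only apply its staticmethod to a callable with...", right? Then it's the staticmethod itself that should be generic, not the class! In your version CountType belongs to the class, so mypy effectively has to pick a type for it as soon as FunctionCount.count() is called; f_name says nothing about it, so it ends up as Never, which is where Callable[[Never], None] comes from. Compare with the following: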

    from typing import ParamSpec, Callable, TypeVar, Concatenate
    
    Param = ParamSpec("Param")
    RetType = TypeVar("RetType")
    CountType = TypeVar("CountType", bound="FunctionCount")
    
    class FunctionCount:
        def __init__(self, count_dict: dict[str, int]) -> None:
            self.count_dict = count_dict
    
        @staticmethod
        def count(
            f_name: str | None = None,
        ) -> Callable[
            [Callable[Concatenate[CountType, Param], RetType]],
            Callable[Concatenate[CountType, Param], RetType],
        ]:
            def decorator(
                func: Callable[Concatenate[CountType, Param], RetType],
            ) -> Callable[Concatenate[CountType, Param], RetType]:
                def wrapper(
                    self: CountType, /, *args: Param.args, **kwargs: Param.kwargs
                ) -> RetType:
                    function_name = f_name or f"{self.__class__.__name__}.{func.__name__}"
                    if function_name not in self.count_dict:
                        self.count_dict[function_name] = 0
                    self.count_dict[function_name] += 1
    
                    return func(self, *args, **kwargs)
    
                return wrapper
                
            return decorator
    

    This passes mypy --strict (mypy playground, pyright playground).

    As an aside, run mypy --strict when you encounter mypy errors that contain Never (or anything "weird" you struggle to understand): it may point out other places where you're doing something mypy does not like. mypy --strict reports a bunch of errors on your original code; one of the most important is about the missing generic argument for the subclass (see here).
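    The difference between a class-scoped and a function-scoped type variable can be seen in isolation with a minimal sketch. Box and both make_identity functions are hypothetical names. At runtime the two versions behave identically; the difference is only in how a type checker scopes T. Going by the Never in the question's error, mypy would be expected to solve T as Never for Box.make_identity(), while the module-level make_identity() should return a callable that stays generic.

```python
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class Box(Generic[T]):
    # Hypothetical class: T belongs to Box, so inside this staticmethod T
    # refers to the class's parameter, not to a fresh one per call.
    @staticmethod
    def make_identity() -> Callable[[T], T]:
        def identity(x: T) -> T:
            return x

        return identity


def make_identity() -> Callable[[T], T]:
    # Here T is scoped to the function. It appears only in the returned
    # callable, so a type checker can keep that callable generic.
    def identity(x: T) -> T:
        return x

    return identity


free = make_identity()  # expected type: a generic identity function
bound = Box.make_identity()  # T has nothing to be inferred from here
```

    Calling free with an int and then a str should type-check; calling bound with anything is where a checker would be expected to complain.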