Tags: python, python-typing, mypy

Type hints for mypy for a decorator over staticmethod/classmethod


I'm writing a helper for a logging library that provides decorators for trace-level (debugging) logging.

The code itself works (it is partially based on an existing library), but I'm struggling to find type annotations that mypy accepts.

The question marks (`???`) show where I have trouble with the types — or maybe the problem is more general.

For staticmethod:

def trace_static_method(_staticmethod: staticmethod) -> staticmethod:
    """Trace a method wrapped in @staticmethod (question draft).

    ``???`` marks the annotations the author could not express, so this
    snippet is not valid Python as written.
    """
    @wraps(_staticmethod.__func__)  # this generates a mypy error (incorrect type)
    def wrapper(*args: ???, **kwargs: ???) -> ???:
        # Delegate to the shared tracing helper with the unwrapped function.
        return _log_trace(_staticmethod.__func__, *args, **kwargs)
    return staticmethod(wrapper)  # this generates a mypy error (incorrect type)

For classmethod:

def trace_class_method(_classmethod: classmethod) -> classmethod:
    """Trace a method wrapped in @classmethod (question draft).

    ``???`` marks the annotations the author could not express, so this
    snippet is not valid Python as written.
    """
    @wraps(_classmethod.__func__)  # this generates a mypy error (incorrect type)
    def wrapper(_cls: ???, *args: ???, **kwargs: ???) -> ???:
        # Bind the original classmethod to the class this wrapper was called on,
        # so ``_cls`` is supplied by the bound method and not logged in ``args``.
        method = _classmethod.__get__(None, _cls)  # this generates a mypy error (incorrect type)
        return _log_trace(method, *args, **kwargs)
    return classmethod(wrapper)  # this generates a mypy error (incorrect type)

Log trace:

def _log_trace(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    """Call ``func(*args, **kwargs)``, logging the call and its result at TRACE level.

    Intended to be called directly from a decorator's ``wrapper``. A CALL
    record is emitted before the call and a RETURN record after it. If
    ``func`` raises, the exception propagates and no RETURN record is emitted.
    """
    name = func.__qualname__
    module = func.__module__
    # depth=2 skips this helper (depth 0) and the decorator's ``wrapper``
    # (depth 1), so the record's location is the code that called the
    # decorated function rather than the wrapper itself.
    logger_ = logger.opt(depth=2)
    logger_.log("TRACE", "{}.{} CALL args={}, kwargs={}", module, name, args, kwargs)
    result = func(*args, **kwargs)
    logger_.log("TRACE", "{}.{} RETURN {}", module, name, result)
    return result

Working types for a simple function decorator:

def trace(func: Callable[P, T]) -> Callable[P, T]:
    """Return *func* wrapped so that every call goes through ``_log_trace``."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # Logging and the real call both happen inside the shared helper.
        return _log_trace(func, *args, **kwargs)

    # Copy name, qualname, docstring etc. from the original function.
    traced = wraps(func)(wrapper)
    return traced

EDIT: Adding correct types for the static method turned out to be straightforward:

def trace_static_method(_staticmethod: staticmethod[P, T]) -> staticmethod[P, T]:

    @wraps(_staticmethod.__func__)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return _log_trace(_staticmethod.__func__, *args, **kwargs)

    return staticmethod(wrapper)

Solution

  • Correct type annotations, checked with mypy 1.6.1:

    P = ParamSpec("P")
    T = TypeVar("T")
    R_co = TypeVar("R_co", covariant=True)
    
    def _log_trace(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
        """Log a function call and return.

        Calls ``func(*args, **kwargs)`` and emits two TRACE records through a
        logger bound with ``trace_name`` / ``trace_module`` extras: one before
        the call (with the positional and keyword arguments) and one after it
        (with the return value). If ``func`` raises, the exception propagates
        and no RETURN record is emitted.
        """
        # depth=2 skips this helper (depth 0) and the decorator's ``wrapper``
        # (depth 1) that calls it, so the record's location is the code that
        # invoked the decorated function rather than the wrapper.
        trace_logger = logger.opt(depth=2).bind(
            trace_name=func.__qualname__,
            trace_module=func.__module__,
        )
        trace_logger.trace("CALL args={}, kwargs={}", args, kwargs)
        result = func(*args, **kwargs)
        trace_logger.trace("RETURN {}", result)
        return result
    
    def trace(func: Callable[P, T]) -> Callable[P, T]:
        """Trace a function.

        Returns a wrapper carrying ``func``'s metadata (via ``functools.wraps``)
        that forwards every call, with its original arguments, to ``_log_trace``.
        """

        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            return _log_trace(func, *args, **kwargs)

        return wraps(func)(wrapper)
    
    
    def trace_instance_method(func: Callable[P, T]) -> Callable[P, T]:
        """Trace a method.

        The wrapper binds ``func`` to the instance passed as the first
        positional argument and calls the bound method with the remaining
        arguments, so ``self`` is not included in the logged ``args``.
        """

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            if not args:
                # No positional ``self`` (e.g. an unbound call with no
                # arguments, or ``self`` passed by keyword). Call ``func``
                # directly so Python raises its usual TypeError (or the call
                # succeeds) instead of an IndexError from ``args[0]``.
                return _log_trace(func, *args, **kwargs)
            _self: object = args[0]
            method = func.__get__(_self, _self.__class__)
            return _log_trace(method, *args[1:], **kwargs)

        return wrapper
    
    
    def trace_static_method(
        _staticmethod: "staticmethod[P, R_co]",
    ) -> "staticmethod[P, R_co]":
        """Trace a method wrapped in @staticmethod.

        Unwraps the plain function, wraps it with trace logging, and returns
        the result re-wrapped in a new ``staticmethod``.
        """
        # ``staticmethod.__func__`` is read-only, so capturing it once is
        # equivalent to reading it on every call.
        target = _staticmethod.__func__

        @wraps(target)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R_co:
            return _log_trace(target, *args, **kwargs)

        traced = staticmethod(wrapper)
        return traced
    
    
    def trace_class_method(
        _classmethod: "classmethod[T, P, R_co]",
    ) -> "classmethod[T, P, R_co]":
        """Trace a method wrapped in @classmethod.

        The returned classmethod receives the owning class as ``_cls``. It
        binds the original classmethod to that class and passes the bound
        method to ``_log_trace``, so the class argument is not included in
        the logged ``args``.
        """
        owner_func = _classmethod.__func__

        @wraps(owner_func)
        def wrapper(_cls: type[T], *args: P.args, **kwargs: P.kwargs) -> R_co:
            bound_method = _classmethod.__get__(None, _cls)
            return _log_trace(bound_method, *args, **kwargs)

        return classmethod(wrapper)