Search code examples
Tags: python, mypy, python-typing

Typing a function decorator with conditional output type, in Python


I have a set of functions that all accept a parameter named `value`, plus arbitrary other keyword parameters.

I have a decorator, `lazy`. Decorated functions normally return as usual, but return a partial function if `value` is None.

How do I type-hint the decorator, whose output depends on the value input?

from functools import partial

def lazy(func):
    """Make ``func`` lazy in its ``value`` keyword argument.

    Calling the decorated function with a non-None ``value`` calls ``func``
    immediately. Otherwise the remaining keyword arguments are bound and a
    ``functools.partial`` waiting for ``value`` is returned. Because ``None``
    is the "not given" marker, ``func`` can never receive ``value=None``
    through this wrapper.
    """
    from functools import wraps  # local import keeps the snippet's import line unchanged

    @wraps(func)  # keep func's __name__ and __doc__ on the decorated result
    def wrapper(value=None, **kwargs):
        if value is not None:
            return func(value=value, **kwargs)
        else:
            return partial(func, **kwargs)
    return wrapper

@lazy
def test_multiply(*, value: float, multiplier: float) -> float:
    """Return ``value`` scaled by ``multiplier``."""
    product = value * multiplier
    return product

@lazy
def test_format(*, value: float, fmt: str) -> str:
    """Render ``value`` with the printf-style format string ``fmt``."""
    rendered = fmt % value
    return rendered

# Eager calls: `value` is supplied, so the wrapped function runs right away.
print('test_multiply 5*2:', test_multiply(value=5, multiplier=2))
print('test_format 7.777 as .2f:', test_format(value=7.777, fmt='%.2f'))

# Lazy call: `value` is omitted, so a partial function comes back instead.
func_mult_11 = test_multiply(multiplier=11)  # returns a partial function
print('Type of func_mult_11:', type(func_mult_11))
print('func_mult_11 5*11:', func_mult_11(value=5))

I'm using mypy and have got most of the way there using `mypy_extensions`, but I haven't managed to type `value` in `wrapper`:

from typing import Callable, TypeVar, ParamSpec, Any, Optional
from mypy_extensions import DefaultNamedArg, KwArg

R = TypeVar("R")  # return type of the decorated function
P = ParamSpec("P")  # parameters of the decorated function

def lazy(func: Callable[P, R]) -> Callable[[DefaultNamedArg(float, 'value'), KwArg(Any)], Any]:
    """Attempted typed version of ``lazy`` (the part the question asks about).

    Calls ``func`` when ``value`` is given, otherwise returns
    ``partial(func, **kwargs)``.
    """
    # NOTE: `value` is still unannotated here -- typing it is the open question.
    # mypy rejects `**kwargs: P.kwargs` on its own; a ParamSpec needs
    # `*args: P.args` as well.
    # `partial` is not imported in this snippet; it relies on
    # `from functools import partial` from the first snippet.
    def wrapper(value = None, **kwargs: P.kwargs) -> R | partial[R]:
        if value is not None:
            return func(value=value, **kwargs)
        else:
            return partial(func, **kwargs)
    return wrapper

How can I type value? And better still, can I do this without mypy extensions?


Solution

  • I see two possible options here. The first is a "more formally correct" but far too permissive approach that relies on the `partial` type hint:

    from __future__ import annotations
    
    from functools import partial
    from typing import Callable, TypeVar, ParamSpec, Any, Optional, Protocol, overload, Concatenate
    
    R = TypeVar("R")  # return type of the decorated function
    P = ParamSpec("P")  # parameters of the decorated function after its leading `value`

    class YourCallable(Protocol[P, R]):
        """Type of a ``lazy``-decorated function.

        Called with a float ``value`` it returns ``R``. Called with ``value``
        omitted or ``None`` it returns ``partial[R]``. ``partial`` accepts any
        arguments, so calls on that result are not checked against the
        original signature.
        """
        # Eager form: value supplied -> call through.
        @overload
        def __call__(self, value: float, *args: P.args, **kwargs: P.kwargs) -> R: ...
        # Lazy form: value omitted/None -> partial that still needs `value`.
        @overload
        def __call__(self, value: None = None, *args: P.args, **kwargs: P.kwargs) -> partial[R]: ...
    
    def lazy(func: Callable[Concatenate[float, P], R]) -> YourCallable[P, R]:
        """Make ``func`` return a ``partial`` when called without ``value``.

        ``func`` must take ``value`` as its first parameter. A non-None
        ``value`` calls ``func`` immediately. Omitting ``value`` (or passing
        ``None``) binds the keyword arguments and returns
        ``partial(func, **kwargs)`` instead.

        Raises:
            ValueError: on a lazy call that also passes positional arguments.
        """
        from functools import wraps  # local import keeps the snippet's import list unchanged

        @wraps(func)  # keep func's __name__ and __doc__ on the decorated result
        def wrapper(value: float | None = None, *args: P.args, **kwargs: P.kwargs) -> R | partial[R]:
            if value is not None:
                return func(value, *args, **kwargs)
            else:
                if args:
                    raise ValueError("Lazy call must provide keyword arguments only")
                return partial(func, **kwargs)
        # mypy cannot match this single-signature wrapper to the overloaded
        # protocol, hence the ignore.
        return wrapper  # type: ignore[return-value]
    
    @lazy
    def test_multiply(value: float, *, multiplier: float) -> float:
        """Return ``value`` scaled by ``multiplier``."""
        product = value * multiplier
        return product

    @lazy
    def test_format(value: float, *, fmt: str) -> str:
        """Render ``value`` with the printf-style format string ``fmt``."""
        rendered = fmt % value
        return rendered

    # Eager calls: `value` is supplied, so the wrapped function runs right away.
    print('test_multiply 5*2:', test_multiply(value=5, multiplier=2))
    print('test_format 7.777 as .2f:', test_format(value=7.777, fmt='%.2f'))

    # Lazy call: `value` is omitted, so a partial function comes back instead.
    func_mult_11 = test_multiply(multiplier=11)  # returns a partial function
    print('Type of func_mult_11:', type(func_mult_11))
    print('func_mult_11 5*11:', func_mult_11(value=5))
    func_mult_11(value=5, multiplier=5)  # OK: partial lets the call override multiplier
    func_mult_11(value='a')  # False negative: we want this to fail, but partial accepts anything
    

    The last two calls show what is good and bad about this approach: `partial` accepts any input arguments, so it is not sufficiently safe. If you want to be able to override the arguments initially provided to the lazy callable, this is probably the best solution.

    Note that I slightly changed the signatures of the input callables: without that, you cannot use `Concatenate`. Note also that `KwArg`, `DefaultNamedArg` and company are all deprecated in favour of protocols. You cannot use a `ParamSpec` with `**kwargs` only; `*args` must also be present. If you trust your type checker, it is fine to use keyword-only callables: all positional calls will be rejected at the type-checking phase.

    However, if you do not need to override the arguments passed to the initial callable, here is another alternative. It is fully safe, but it emits false positives if you do try to override them.

    from __future__ import annotations
    
    from functools import partial
    from typing import Callable, TypeVar, ParamSpec, Any, Optional, Protocol, overload, Concatenate
    
    # Covariant: it only appears in return positions of the protocols below.
    _R_co = TypeVar("_R_co", covariant=True)
    R = TypeVar("R")  # return type of the decorated function
    P = ParamSpec("P")  # parameters of the decorated function after its leading `value`

    class ValueOnlyCallable(Protocol[_R_co]):
        """A callable that takes only ``value`` and returns ``_R_co``."""
        def __call__(self, value: float) -> _R_co: ...

    class YourCallableTooStrict(Protocol[P, _R_co]):
        """Type of a ``lazy_strict``-decorated function.

        Called with a float ``value`` it returns ``_R_co``. Called without
        ``value`` it returns a ``ValueOnlyCallable``, so the lazy result only
        type-checks when called with ``value``. Passing other keywords again is
        rejected by the checker even though the runtime ``partial`` accepts them.
        """
        # Eager form: value supplied -> call through.
        @overload
        def __call__(self, value: float, *args: P.args, **kwargs: P.kwargs) -> _R_co: ...
        # Lazy form: value omitted/None -> callable that only takes `value`.
        @overload
        def __call__(self, value: None = None, *args: P.args, **kwargs: P.kwargs) -> ValueOnlyCallable[_R_co]: ...
    
    
    def lazy_strict(func: Callable[Concatenate[float, P], R]) -> YourCallableTooStrict[P, R]:
        """Make ``func`` lazy in ``value``, typing the lazy result as value-only.

        ``func`` must take ``value`` as its first parameter. A non-None
        ``value`` calls ``func`` immediately. Without ``value``, the keyword
        arguments are bound with ``functools.partial``. The declared return
        type (``ValueOnlyCallable``) then only admits a ``value`` argument,
        even though the runtime ``partial`` would accept more.

        Raises:
            ValueError: on a lazy call that also passes positional arguments.
        """
        from functools import wraps  # local import keeps the snippet's import list unchanged

        @wraps(func)  # keep func's __name__ and __doc__ on the decorated result
        def wrapper(value: float | None = None, *args: P.args, **kwargs: P.kwargs) -> R | partial[R]:
            if value is not None:
                return func(value, *args, **kwargs)
            else:
                if args:
                    raise ValueError("Lazy call must provide keyword arguments only")
                return partial(func, **kwargs)
        # mypy cannot match this single-signature wrapper to the overloaded
        # protocol, hence the ignore.
        return wrapper  # type: ignore[return-value]
    
    @lazy_strict
    def test_multiply_strict(value: float, *, multiplier: float) -> float:
        """Return ``value`` scaled by ``multiplier``."""
        product = value * multiplier
        return product

    @lazy_strict
    def test_format_strict(value: float, *, fmt: str) -> str:
        """Render ``value`` with the printf-style format string ``fmt``."""
        rendered = fmt % value
        return rendered

    # Eager calls: `value` is supplied, so the wrapped function runs right away.
    print('test_multiply 5*2:', test_multiply_strict(value=5, multiplier=2))
    print('test_format 7.777 as .2f:', test_format_strict(value=7.777, fmt='%.2f'))

    # Lazy call: `value` is omitted; the result is typed as ValueOnlyCallable.
    func_mult_11_strict = test_multiply_strict(multiplier=11)  # returns a partial function
    print('Type of func_mult_11:', type(func_mult_11_strict))
    print('func_mult_11 5*11:', func_mult_11_strict(value=5))
    # False positive: OK at runtime, but not allowed by mypy.
    # E: Unexpected keyword argument "multiplier" for "__call__" of "ValueOnlyCallable"  [call-arg]
    func_mult_11_strict(value=5, multiplier=5)
    # Expected error.
    # E: Argument "value" to "__call__" of "ValueOnlyCallable" has incompatible type "str"; expected "float"  [arg-type]
    func_mult_11_strict(value='a')
    

    You can also mark `value` as keyword-only in the `ValueOnlyCallable` definition if you like; I just don't think that is reasonable for a function with only one argument.

    You can compare both approaches in the playground.

    If you do not want to use an ignore comment, the more verbose option below should work. However, I do not think the extra verbosity is worth removing one ignore comment; it's up to you to decide.

    def lazy_strict(func: Callable[Concatenate[float, P], R]) -> YourCallableTooStrict[P, R]:
        """Overload-based ``lazy_strict`` that needs no ``type: ignore``.

        The two ``wrapper`` overloads spell out the eager form (float
        ``value`` -> ``R``) and the lazy form (no ``value`` -> a callable
        taking only ``value``). The returned function should therefore be
        accepted as ``YourCallableTooStrict`` without an ignore comment.

        Raises:
            ValueError: on a lazy call that also passes positional arguments.
        """
        @overload
        def wrapper(value: float, *args: P.args, **kwargs: P.kwargs) -> R: ...
        @overload
        def wrapper(value: None = None, *args: P.args, **kwargs: P.kwargs) -> ValueOnlyCallable[R]: ...
        def wrapper(value: float | None = None, *args: P.args, **kwargs: P.kwargs) -> R | ValueOnlyCallable[R]:
            if value is None:
                # Lazy path: only keyword arguments may be bound ahead of time.
                if args:
                    raise ValueError("Lazy call must provide keyword arguments only")
                return partial(func, **kwargs)
            # Eager path: a value was supplied, so call straight through.
            return func(value, *args, **kwargs)
        return wrapper
    

    Here is also a Pyright playground, because mypy failed to find a mistake in my original answer and Pyright did.