Tags: python, generics, types, mypy, pyright

How do I check that parameters of a variadic callable are all of a certain subclass?


This might be a tough one.

Suppose I have a type

from typing import Mapping, Sequence, Union
JSON = Union[Mapping[str, "JSON"], Sequence["JSON"], str, int, float, bool, None]
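For reference, the recursive alias can be mirrored by a value-level runtime check. This is a minimal sketch: `is_json` is a hypothetical helper (not part of `typing`). It inspects actual values rather than annotations, and it does not guard against self-referential containers.

```python
from collections.abc import Mapping, Sequence


def is_json(value: object) -> bool:
    """Hypothetical helper: True if `value` matches the recursive JSON alias.

    Mirrors Mapping[str, JSON] | Sequence[JSON] | str | int | float | bool | None.
    """
    # Scalars. bool is a subclass of int, so the int check covers it too.
    if value is None or isinstance(value, (str, int, float)):
        return True
    # Mappings: keys must be str, values must recursively be JSON.
    if isinstance(value, Mapping):
        return all(isinstance(k, str) and is_json(v) for k, v in value.items())
    # Sequences: str is also a Sequence, but it was already accepted above.
    if isinstance(value, Sequence):
        return all(is_json(item) for item in value)
    # Anything else (sets, arbitrary objects, ...) is not JSON.
    return False


print(is_json({"a": [1, 2.5, None, True]}))
print(is_json({1, 2}))
```

Sets fail because they are neither a `Mapping` nor a `Sequence`, which matches how the static alias treats `set[int]`.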

And I have a function

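from typing import Callable
def memoize[**P, T: JSON](fn: Callable[P, T]) -> Callable[P, T]:
    def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> T:
        # ...make the function memoizable
        return fn(*args, **kwargs)
    return wrapped_fn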

How do I constrain the parameters of fn to all be subtypes of JSON? Alternatively, if this can't be done statically, how do I check this inside memoize before creating the wrapper?

I tried giving a bound to the ParamSpec variable **P, but it seems this isn't supported. I also tried issubclass, but it doesn't play nicely with type hints: it rejects parameterized generics such as Mapping[str, JSON].
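A small sketch of why issubclass struggles with these hints (Python 3.9+). `issubclass_raises` is a hypothetical helper written only for this illustration. It shows that a parameterized generic is refused outright, and that stripping the parameters with `typing.get_origin` makes the check run but loses the key and value types.

```python
import typing
from collections.abc import Mapping


def issubclass_raises(cls: type, hint: object) -> bool:
    """Hypothetical helper: True if issubclass(cls, hint) raises TypeError."""
    try:
        issubclass(cls, hint)  # type: ignore[arg-type]
    except TypeError:
        return True
    return False


hint = Mapping[str, int]

# A parameterized generic is not a class, so issubclass refuses it.
print(issubclass_raises(dict, hint))

# get_origin returns the bare class, which issubclass accepts,
# but the (str, int) parameters are no longer part of the check.
origin = typing.get_origin(hint)
print(origin is Mapping, issubclass(dict, origin))
```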


Solution

  • There is currently no way to do this if fn has an arbitrary signature. I believe the next best thing is to generate errors at the call site (see the Pyright playground):

    import collections.abc as cx
    import typing as t
    
    type JSON = cx.Mapping[str, JSON] | cx.Sequence[JSON] | str | int | float | bool | None
    
    class _JSONOnlyCallable(t.Protocol):
        def __call__(self, /, *args: JSON, **kwargs: JSON) -> JSON: ...
    
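    def memoize[F: cx.Callable[..., t.Any]](fn: F, /) -> F | _JSONOnlyCallable:
        # Caching omitted. A call on an `F | _JSONOnlyCallable` value must type-check
        # against both members, so non-JSON arguments are flagged at the call site.
        return fn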
    
    @memoize
    def f(a: int, b: str, c: set[int]) -> str: ...
    
    >>> f(1, "", {1, 2})
                 ^^^^^^
    pyright: Argument of type "set[int]" cannot be assigned to parameter "args" of type "JSON" in function "__call__"