Tags: python, mypy, python-typing

In Python typing, is there a way to specify combinations of allowed generic types that depend on each other?


I have a pretty small but versatile (maybe too versatile) class in Python that effectively has two generic types that are constrained together. My code (still poorly typed) should show my intent:

class ApplyTo:
    def __init__(
        self,
        *transforms: Callable[..., Any],
        to: Any | Sequence[Any],
        dispatch: Literal['separate', 'joint'] = 'separate',
    ):
        self._transform = TransformsPipeline(*transforms)
        self._to = to if isinstance(to, Sequence) else [to]
        self._dispatch = dispatch

    def __call__(self, data: MutableSequence | MutableMapping):
        if self._dispatch == 'separate':
            for key in self._to:
                data[key] = self._transform(data[key])
            return data

        if self._dispatch == 'joint':
            args = [data[key] for key in self._to]
            transformed = self._transform(*args)
            for output, key in zip(transformed, self._to):
                data[key] = output
            return data

        assert False

I have double-checked that this works at runtime, and the logic is pretty straightforward, but the typing is really horrendous.

The idea is that when to is an int, data should be MutableSequence | MutableMapping[int, Any]; when to is any other Hashable, data should be MutableMapping[Hashable (or whatever the type of to is), Any]. I know that int is itself Hashable, which doesn't make this any easier.

My very poor attempt at typing this looks like this:

T = TypeVar('T', bound=Hashable | int)
C = TypeVar('C', bound=MutableMapping[T, Any] | MutableSequence)


class ApplyTo(Generic[C, T]):
    def __init__(
        self,
        *transforms: Callable[..., Any],
        to: T | Sequence[T],
        dispatch: Literal['separate', 'joint'] = 'separate',
    ):
        self._transform = TransformsPipeline(*transforms)
        self._to = to if isinstance(to, Sequence) else [to]
        self._dispatch = dispatch

    def __call__(self, data: C):
        if self._dispatch == 'separate':
            for key in self._to:
                data[key] = self._transform(data[key])
            return data

        if self._dispatch == 'joint':
            args = [data[key] for key in self._to]
            transformed = self._transform(*args)
            for output, key in zip(transformed, self._to):
                data[key] = output
            return data

        assert False

Which makes mypy complain (no surprise):

error: Type variable "task_driven_sr.transforms.generic.T" is unbound  [valid-type]
note: (Hint: Use "Generic[T]" or "Protocol[T]" base class to bind "T" inside a class)
note: (Hint: Use "T" in function signature to bind "T" inside a function)

Is there any way to type hint this correctly and somehow bind the types of to and data together? Maybe my approach is flawed and I have reached a dead end.

Edit: Fixed some code inside the 'joint' branch of dispatch. It was unrelated to the typing in question, but I have corrected it anyway.


Solution

  • You can replace the union of MutableMapping and MutableSequence with a Protocol:

    import typing
    
    DispatchType = typing.Literal['separate', 'joint']
    
    # `K` must be declared with `contravariant=True`, otherwise mypy errors with
    # 'Invariant type variable "K" used in protocol where contravariant one is expected'
    K = typing.TypeVar('K', contravariant=True)
    class Indexable(typing.Protocol[K]):
        def __getitem__(self, key: K):
            pass
        
        def __setitem__(self, key: K, value: typing.Any):
            pass
    
    # Accepts only hashable types (including `int`s)
    H = typing.TypeVar('H', bound=typing.Hashable)
    class ApplyTo(typing.Generic[H]):
        _to: typing.Sequence[H]
        _dispatch: DispatchType
        _transform: typing.Callable[..., typing.Any]  # TODO Initialize `_transform`
    
        def __init__(self, to: typing.Sequence[H] | H, dispatch: DispatchType = 'separate') -> None:
            self._dispatch = dispatch
            self._to = to if isinstance(to, typing.Sequence) else [to]
    
        def __call__(self, data: Indexable[H]) -> typing.Any:
            if self._dispatch == 'separate':
                for key in self._to:
                    data[key] = self._transform(data[key])
                return data
    
            if self._dispatch == 'joint':
                args = [data[key] for key in self._to]
                return self._transform(*args)
    
            assert False
    

    Usage:

    def main() -> None:
        r0 = ApplyTo(to=0)([1, 2, 3])
        # typechecks
        r0 = ApplyTo(to=0)({1: 'a', 2: 'b', 3: 'c'})
        # typechecks
    
        r1 = ApplyTo(to='a')(['b', 'c', 'd'])
        # does not typecheck: Argument 1 to "__call__" of "ApplyTo" has incompatible type "list[str]"; expected "Indexable[str]"
        r1 = ApplyTo(to='a')({'b': 1, 'c': 2, 'd': 3}) 
        # typechecks
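
The snippet above leaves `_transform` uninitialized, so it cannot be run as written. Below is a minimal runnable sketch of the same Protocol-based idea, under a few assumptions: a single hypothetical `transform` callable stands in for the question's `TransformsPipeline` (which is not shown), the 'joint' branch writes results back as in the question, and the final `assert False` is swapped for a `ValueError`. It is meant to illustrate the runtime behaviour, not to be the definitive implementation, and I have not re-run mypy on this exact version.

```python
from typing import (Any, Callable, Generic, Hashable, Literal, Protocol,
                    Sequence, TypeVar, Union)

# Key type of the protocol; contravariant because it only appears in argument position.
K = TypeVar('K', contravariant=True)


class Indexable(Protocol[K]):
    """Anything supporting `obj[key]` and `obj[key] = value`."""

    def __getitem__(self, key: K) -> Any: ...

    def __setitem__(self, key: K, value: Any) -> None: ...


H = TypeVar('H', bound=Hashable)


class ApplyTo(Generic[H]):
    def __init__(
        self,
        transform: Callable[..., Any],  # hypothetical stand-in for TransformsPipeline
        to: Union[Sequence[H], H],
        dispatch: Literal['separate', 'joint'] = 'separate',
    ) -> None:
        self._transform = transform
        # Same normalization as in the answer: a lone key becomes a one-element list.
        self._to = to if isinstance(to, Sequence) else [to]
        self._dispatch = dispatch

    def __call__(self, data: Indexable[H]) -> Any:
        if self._dispatch == 'separate':
            # Transform each selected item independently, in place.
            for key in self._to:
                data[key] = self._transform(data[key])
            return data

        if self._dispatch == 'joint':
            # Pass all selected items to one call and write the outputs back.
            args = [data[key] for key in self._to]
            for output, key in zip(self._transform(*args), self._to):
                data[key] = output
            return data

        raise ValueError(f'unknown dispatch: {self._dispatch!r}')


doubled = ApplyTo(lambda x: x * 2, to=[0, 2])([1, 2, 3])
upper = ApplyTo(str.upper, to='a')({'a': 'x', 'b': 'y'})
swapped = ApplyTo(lambda a, b: (b, a), to=[0, 1], dispatch='joint')([1, 2, 3])
print(doubled, upper, swapped)
```

One caveat that applies to both versions: a `str` is itself a `Sequence`, so a multi-character string key passed directly as `to` would be iterated character by character. Wrapping it in a list (`to=['key']`) avoids that.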