I have a pretty small but (maybe too) versatile class in Python that effectively has two generic types that are constrained together. My (still poorly typed) code should show my intent:
class ApplyTo:
def __init__(
self,
*transforms: Callable[..., Any],
to: Any | Sequence[Any],
dispatch: Literal['separate', 'joint'] = 'separate',
):
self._transform = TransformsPipeline(*transforms)
self._to = to if isinstance(to, Sequence) else [to]
self._dispatch = dispatch
def __call__(self, data: MutableSequence | MutableMapping):
if self._dispatch == 'separate':
for key in self._to:
data[key] = self._transform(data[key])
return data
if self._dispatch == 'joint':
args = [data[key] for key in self._to]
transformed = self._transform(*args)
for output, key in zip(transformed, self._to):
data[key] = output
return data
assert False
I have double-checked that this works at runtime and is pretty straightforward, but the typing is really horrendous.
So the idea is that when `to` is set to an `int`, `data` should be `MutableSequence | MutableMapping[int, Any]`; when `to` is `Hashable`, `data` should be `MutableMapping[Hashable (or whatever the type of to is), Any]`. I know that `int` is `Hashable`, which doesn't make this easier.
My very poor attempt at typing this looks like this:
T = TypeVar('T', bound=Hashable | int)
C = TypeVar('C', bound=MutableMapping[T, Any] | MutableSequence)
class ApplyTo(Generic[C, T]):
def __init__(
self,
*transforms: Callable[..., Any],
to: T | Sequence[T],
dispatch: Literal['separate', 'joint'] = 'separate',
):
self._transform = TransformsPipeline(*transforms)
self._to = to if isinstance(to, Sequence) else [to]
self._dispatch = dispatch
def __call__(self, data: C):
if self._dispatch == 'separate':
for key in self._to:
data[key] = self._transform(input[key])
return input
if self._dispatch == 'joint':
args = [data[key] for key in self._to]
transformed = self._transform(*args)
for output, key in zip(transformed, self._to):
data[key] = output
return data
assert False
Which makes mypy complain (no surprise):
error: Type variable "task_driven_sr.transforms.generic.T" is unbound [valid-type]
note: (Hint: Use "Generic[T]" or "Protocol[T]" base class to bind "T" inside a class)
note: (Hint: Use "T" in function signature to bind "T" inside a function)
Is there any way to type-hint this correctly and somehow bind the types of `to` and `data` together? Maybe my approach is flawed and I have reached a dead end.
Edit: Fixed some code inside the `'joint'` dispatch branch; it was unrelated to the typing question, but I have corrected it anyway.
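You can replace the union of `MutableMapping` and `MutableSequence` with a `Protocol` that is generic in the key type, and bind that key type from `to` through the class's own type parameter: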
import typing
DispatchType = typing.Literal['separate', 'joint']
# `P` must be declared with `contravariant=True`, otherwise it errors with
# 'Invariant type variable "P" used in protocol where contravariant one is expected'
K = typing.TypeVar('K', contravariant=True)
class Indexable(typing.Protocol[K]):
def __getitem__(self, key: K):
pass
def __setitem__(self, key: K, value: typing.Any):
pass
# Accepts only hashable types (including `int`s)
H = typing.TypeVar('H', bound=typing.Hashable)
class ApplyTo(typing.Generic[H]):
_to: typing.Sequence[H]
_dispatch: DispatchType
_transform: typing.Callable[..., typing.Any] # TODO Initialize `_transform`
def __init__(self, to: typing.Sequence[H] | H, dispatch: DispatchType = 'separate') -> None:
self._dispatch = dispatch
self._to = to if isinstance(to, typing.Sequence) else [to]
def __call__(self, data: Indexable[H]) -> typing.Any:
if self._dispatch == 'separate':
for key in self._to:
data[key] = self._transform(data[key])
return data
if self._dispatch == 'joint':
args = [data[key] for key in self._to]
return self._transform(*args)
assert False
Usage:
def main() -> None:
    """Show which calls mypy accepts once ``H`` is bound from ``to``.

    This demo is for static type checking only. Several calls would raise at
    runtime: ``0`` is not a key of the second dict, a ``list`` cannot be
    indexed by ``'a'``, and ``'a'`` is not a key of the last dict.
    """
    # typechecks: H = int, and list[int] is an Indexable[int]
    r0 = ApplyTo(to=0)([1, 2, 3])
    # typechecks: H = int, and dict[int, str] is an Indexable[int]
    r0 = ApplyTo(to=0)({1: 'a', 2: 'b', 3: 'c'})
    # does not typecheck: Argument 1 to "__call__" of "ApplyTo" has incompatible type "list[str]"; expected "Indexable[str]"
    r1 = ApplyTo(to='a')(['b', 'c', 'd'])
    # typechecks: H = str, and dict[str, int] is an Indexable[str]
    r1 = ApplyTo(to='a')({'b': 1, 'c': 2, 'd': 3})