
Silence mypy arg-type error when using strategy pattern


Minimal example:

from typing import overload, TypeVar, Generic

class EventV1: pass
class EventV2: pass

class DataGathererV1:
    def process(self, event: EventV1): pass
    def process2(self, event: EventV1): pass

class DataGathererV2:
    def process(self, event: EventV2): pass
    def process2(self, event: EventV2): pass


class Dispatcher:
    def __init__(self):
        self.worker_v1: DataGathererV1 = DataGathererV1()
        self.worker_v2: DataGathererV2 = DataGathererV2()

    def dispatch(self, event: EventV1 | EventV2):
        handler: DataGathererV1 | DataGathererV2 = self.worker_v1 if isinstance(event, EventV1) else self.worker_v2

        handler.process(event)
        # Common logic

        handler.process2(event)
        # Common logic

        handler.process(event)
        # etc...

In the code above, I'm using a sort of "strategy" pattern to process events. I would like to avoid splitting my dispatch method into two methods, as I don't think it makes sense and it would duplicate code.

Mypy gives me the following errors, and I don't know how to type my code properly to avoid them.

example.py:36: error: Argument 1 to "process" of "DataGathererV1" has incompatible type "EventV1 | EventV2"; expected "EventV1"  [arg-type]
example.py:36: error: Argument 1 to "process" of "DataGathererV2" has incompatible type "EventV1 | EventV2"; expected "EventV2"  [arg-type]
example.py:40: error: Argument 1 to "process2" of "DataGathererV1" has incompatible type "EventV1 | EventV2"; expected "EventV1"  [arg-type]
example.py:40: error: Argument 1 to "process2" of "DataGathererV2" has incompatible type "EventV1 | EventV2"; expected "EventV2"  [arg-type]
example.py:44: error: Argument 1 to "process" of "DataGathererV1" has incompatible type "EventV1 | EventV2"; expected "EventV1"  [arg-type]
example.py:44: error: Argument 1 to "process" of "DataGathererV2" has incompatible type "EventV1 | EventV2"; expected "EventV2"  [arg-type]

I tried adding an overloaded function that returns both my handler and my event at the same time, to force the types to be "linked" (see below), but it did not solve my problem.

@overload
def __handler(self, event: EventV1) -> tuple[DataGathererV1, EventV1]:
    ...

@overload
def __handler(self, event: EventV2) -> tuple[DataGathererV2, EventV2]:
    ...

def __handler(self, event: EventV1 | EventV2) -> tuple[DataGathererV1 | DataGathererV2, EventV1 | EventV2]:
    return (
        self.worker_v1 if isinstance(event, EventV1) else self.worker_v2,
        event
    )

My "dream" solution would be a way to tell mypy that the type of event and the type of handler are "linked" together.

What I would like to avoid is widening the event parameter of each DataGathererVX to EventV1 | EventV2 and adding an assert isinstance(event, EventVX) at the beginning of each method: my aim is to get errors when these methods are called with the wrong type of event.


Solution

  • I think the issue here is that two different types are being conflated. Consider the protocol:

    from typing import Generic, Protocol, TypeVar

    T = TypeVar("T", contravariant=True)

    class DataGatherer(Generic[T], Protocol):
        def process(self, event: T):
            pass
        def process2(self, event: T):
            pass
    

    We can say DataGathererV1 is a subtype of DataGatherer[EventV1], and DataGathererV2 is a subtype of DataGatherer[EventV2]. So far so good.

    The type we need in the dispatch method, however, is actually DataGatherer[EventV1 | EventV2], because handler is called with an event whose static type is EventV1 | EventV2. Since:

    DataGatherer[EventV1 | EventV2] != DataGatherer[EventV1] | DataGatherer[EventV2]
    

    ...we have an issue.

    A Solution

    By using a cast, we can convert handler to the type we desire:

    from typing import Generic, Protocol, TypeVar, cast
    
    class EventV1:
        pass
    class EventV2:
        pass
    
    T = TypeVar("T", contravariant=True)
    
    class DataGatherer(Generic[T], Protocol):
        def process(self, event: T):
            pass
        def process2(self, event: T):
            pass
    
    class DataGathererV1(DataGatherer[EventV1]):
        def process(self, event: EventV1):
            pass
        def process2(self, event: EventV1):
            pass
    
    class DataGathererV2(DataGatherer[EventV2]):
        def process(self, event: EventV2):
            pass
        def process2(self, event: EventV2):
            pass
    
    class Dispatcher:
        def __init__(self):
            self.worker_v1 = DataGathererV1()
            self.worker_v2 = DataGathererV2()
    
        def dispatch(self, event: EventV1 | EventV2):
            handler = cast(
                DataGatherer[EventV1 | EventV2],
                self.worker_v1 if isinstance(event, EventV1) else self.worker_v2
            )
            handler.process(event)
            # Common logic
    
            handler.process2(event)
            # Common logic
    
            handler.process(event)
            # etc...
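
One caveat with the cast approach: typing.cast only changes what the type checker believes and returns its argument unchanged at runtime, so the pairing of event and handler rests entirely on the isinstance check written next to it. A small sketch of that behaviour (the str return value is an addition for illustration and is not in the original classes):

```python
from typing import cast


class EventV1: ...
class EventV2: ...


class DataGathererV1:
    def process(self, event: EventV1) -> str:
        # Report the real runtime type of the argument.
        return f"V1 gatherer got {type(event).__name__}"


worker = DataGathererV1()

# A deliberately wrong cast: the checker now sees an EventV1,
# but no conversion or validation happens at runtime.
mismatched = cast(EventV1, EventV2())

# This call type-checks, yet the object passed in is still an EventV2.
result = worker.process(mismatched)
```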
    

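    You could also annotate worker_v1 and worker_v2 with the bare DataGatherer protocol to circumvent this, although it is less explicit: an unparameterised DataGatherer is treated as DataGatherer[Any], so mypy no longer checks the event type passed to these attributes: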

    
    class Dispatcher:
        def __init__(self):
            self.worker_v1: DataGatherer = DataGathererV1()
            self.worker_v2: DataGatherer = DataGathererV2()
    
        def dispatch(self, event: EventV1 | EventV2):
            handler: DataGatherer[EventV1 | EventV2] = self.worker_v1 if isinstance(event, EventV1) else self.worker_v2
    
            handler.process(event)
            # Common logic
    
            handler.process2(event)
            # Common logic
    
            handler.process(event)
            # etc...
    
    

    You could also split it into two if blocks, as previously suggested, but otherwise I'm not sure there's any other way to go about it. Hope this is useful!