Search code examples
Tags: python, generics, mypy, python-typing

How do I safely type a function that accepts a generic container class?


from __future__ import annotations

import logging
from datetime import datetime, UTC
from typing import Any, Generic, Self, Protocol, TypeVar

from pydantic import AwareDatetime, BaseModel

logger = logging.getLogger(__name__)
# Declared invariant here. The inline comments further down compare this with
# `TypeVar('EventDataT_co', covariant=True)`; the covariant form is the one that
# makes mypy report "Cannot use a covariant type variable as a parameter" on
# `Event.from_data`, which takes `*data: EventDataT_co`.
EventDataT_co = TypeVar('EventDataT_co')


class Event(BaseModel, Generic[EventDataT_co]):
    # Pydantic model pairing a timezone-aware timestamp with a tuple of payload
    # items, generic over the payload item type `EventDataT_co`.
    # (Comments rather than a class docstring: pydantic would copy a class
    # docstring into the model's JSON-schema description.)
    raised_at: AwareDatetime  # pydantic rejects naive datetimes for this field
    data: tuple[EventDataT_co, ...]  # payload items, in the order they were given

    @classmethod
    def from_data(cls, *data: EventDataT_co) -> Self:
        """Create an event holding ``data``, stamped with the current UTC time.

        This parameter position is where mypy rejects ``EventDataT_co`` when the
        type variable is declared covariant.
        """
        return cls(raised_at=datetime.now(UTC), data=data)


class EventRepository(Protocol):
    """Interface for persisting events, written three ways to compare annotations.

    Each method documents the trade-off of its ``*events`` annotation; the
    bodies are docstring-only.
    """

    def save_events_unsafe(self, *events: Event[Any]) -> None:  # same as a bare `*events: Event` (omitted type arguments default to Any)
        """Accept events with any payload type.

        Not type-safe: because ``.data`` items are ``Any``, an implementation can
        make unchecked (possibly wrong) assumptions about them.
        """

    def save_events_safe(self, *events: Event[object]) -> None:
        """Accept events whose payload type is unknown, typed as ``object``.

        Type-safe for implementations, which must narrow ``.data`` items before
        using them. Callers can only pass e.g. ``Event[int]`` here if ``Event``
        is covariant, and mypy rejects the ``from_data`` constructor when it is.
        """

    def save_events_gen(self, *events: Event[EventDataT_co]) -> None:
        """Generic variant: ``EventDataT_co`` is bound separately for each call.

        Reads as a generic function even though the operation does not depend on
        the payload type (it is really ``.data``-agnostic). With an invariant
        ``Event``, all events in a single call must share one payload type.
        """


class InheritedEventRepository(EventRepository):
    """Implementation that deliberately reads ``.id`` on every payload item.

    The mistake is intentional: the trailing comments record whether the type
    checker catches it under each of the three ``EventRepository`` annotations.
    """

    def save_events_unsafe(self, *events: Event[Any]) -> None:
        # `Any` switches off checking of the payload items, so `.id` goes unreported.
        logger.info(str([data.id for event in events for data in event.data]))  # Type-checker says ok, dev made an uncaught mistake

    def save_events_safe(self, *events: Event[object]) -> None:
        # `object` requires narrowing before attribute access, so `.id` is rejected.
        logger.info(str([data.id for event in events for data in event.data]))  # Type-checker says error: "object" has no attribute "id"  [attr-defined]

    def save_events_gen(self, *events: Event[EventDataT_co]) -> None:
        logger.info(str([data.id for event in events for data in event.data]))
        # Type-checker error: "EventDataT_co" has no attribute "id"  [attr-defined] if covariant
        # No error if invariant


# Demonstration. The trailing comments record mypy's output, for both an
# invariant and a covariant `EventDataT_co`. The snippet is meant for static
# checking: executed as a script, the `save_events_unsafe` call raises
# AttributeError, because the `int` and `str` payloads have no `.id` attribute.
event_1 = Event.from_data(1)
event_2 = Event.from_data('foo')
from typing import reveal_type  # `typing.reveal_type` exists at runtime only on Python 3.11+
print(reveal_type(event_1))  # Event[builtins.int]
print(reveal_type(event_2))  # Event[builtins.str]
InheritedEventRepository().save_events_unsafe(event_1, event_2)
InheritedEventRepository().save_events_safe(event_1, event_2)
# error: Argument 1 to "save_events_safe" of "InheritedEventRepository" has incompatible type "Event[int]"; expected "Event[object]"  [arg-type] if invariant
# No error if covariant
InheritedEventRepository().save_events_gen(event_1, event_2)
# error: Cannot infer type argument 1 of "save_events_gen" of "InheritedEventRepository"  [misc] if invariant
# No error if covariant

Mypy throws:

error: Cannot use a covariant type variable as a parameter [misc]

on the `def from_data(...)` constructor line (this happens when `EventDataT_co` is declared with `covariant=True`). I don't want to make `Event` covariant, but it seems to be the only way to get mypy to accept calls to `save_events_safe`. I also don't want to use `Event[Any]` in `save_events`, because `save_events` will be overridden in subclasses, and I don't want the developer who implements it to lose type-checking safety. Finally, `save_events_gen(self, *events: Event[EventDataT_co])` has the same problem as `save_events_safe`: it only works if `Event` is covariant.

That means I'm stuck with one of two options. One is a covariant version that doesn't allow a custom constructor (and `Event` really shouldn't be used covariantly outside of this case). The other is an invariant version that forces me to use `Any` for payload-agnostic functions. How do I solve this?


Solution

  • I have faced this myself before. Let's state the facts:

    • `Event` is generic. That makes sense: it is a class whose payload type is left for library users to fill in.
    • We want a function that uses `Event` but doesn't care what its properties are, beyond those that are generic.

    This representation:

    def save_events(*events: Event[T]) -> None:
        """Save ``events``; ``T`` is bound per call (with an invariant ``Event``, all events in one call share it)."""
    

    This satisfies both conditions. Being "data-agnostic" and being "generic" are effectively the same thing, if you think of "generic" as "independent of the properties of some internal part".

    But what about the covariance issue?

    The code segment below:

    from __future__ import annotations
    from datetime import datetime, UTC
    from typing import  Generic, Self, TypeVar
    from pydantic import AwareDatetime, BaseModel
    
    
    # Invariant (the default). mypy rejects covariant type variables as method
    # parameter types ("Cannot use a covariant type variable as a parameter"),
    # so keeping this one invariant is what allows `Event.from_data(*data: EventDataT)`.
    EventDataT = TypeVar('EventDataT')
    
    class Event(BaseModel, Generic[EventDataT]):
        # A timestamped batch of payload items of type EventDataT.
        # EventDataT is invariant, which is what lets `from_data` take it as a
        # parameter type (mypy rejects covariant type variables there).
        # Comments rather than a class docstring: pydantic would copy a class
        # docstring into the model's JSON-schema description.

        raised_at: AwareDatetime
        data: tuple[EventDataT, ...]

        @classmethod
        def from_data(cls, *data: EventDataT) -> Self:
            """Build an event holding ``data``, stamped with the current UTC time."""
            timestamp = datetime.now(UTC)
            return cls(data=data, raised_at=timestamp)
    
    
    # A separate type variable for the free function `save_events` below.
    # NOTE(review): PEP 484 states that variance applies only to generic classes,
    # not to generic functions, so `covariant=True` should not change how
    # `save_events` is checked. What matters is that `Event` itself is invariant.
    EventDataT_co = TypeVar('EventDataT_co', covariant=True)
    
    
    def save_events(*events: Event[EventDataT_co]) -> None:
        """Accept one or more events whose payloads share a single type.

        ``Event`` is invariant, so the type variable resolves to one exact payload
        type per call. Mixing e.g. ``Event[int]`` and ``Event[str]`` in a single
        call is therefore rejected by mypy. The explicit ``-> None`` is required
        under ``mypy --strict`` (``disallow_untyped_defs``), which otherwise
        reports "Function is missing a return type annotation".
        """
    
    class Test:
        """Placeholder payload type for the example `save_events` call."""
    
    # Every event here has the same payload type (`Test`). Because `Event` is
    # invariant, a call that mixes payload types (e.g. an `Event[int]` and an
    # `Event[str]`) leaves no single solution for the type variable, and mypy
    # reports "Cannot infer type argument".
    save_events(Event[Test].from_data(Test()))
    

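    ...type checks just fine for me in both pyright and mypy strict mode. Note that mypy's `--strict` flags functions with a missing return annotation, so it requires `save_events` to be declared `-> None`. By introducing a new type var, we can make one covariant and the other not covariant. However, per PEP 484, variance only affects generic classes, not generic functions. What actually makes this work is that `Event` itself stays invariant, which is what permits the `from_data` constructor. The flip side is that every event passed to a single `save_events` call must have the same payload type. Mixing `Event[int]` and `Event[str]` in one call is still rejected with "Cannot infer type argument", exactly as in the question. Hope this helps!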