
How to create a factory function that preserves generic type information in Python?


I have the following function that returns a class with a generic type parameter:

from typing import Any, Callable, Generic, TypeVar

T = TypeVar('T')

def create_property(event_bus: EventBus):
    class Property(Generic[T]):
        def __init__(self, validator: Callable[[T], bool]):
            self._validator = validator

        def __set_name__(self, obj: Any, name: str):
            self.name = name

        def __get__(self, obj: Any, type: Any) -> T:
            return obj.__dict__.get(self.name)

        def __set__(self, obj: Any, value: T):
            if not self._validator(value):
                raise ValueError("Invalid value")
            obj.__dict__[self.name] = value
            event_bus.publish()

    return Property

What I'm trying to do here is to create a class that has an EventBus bound to it so that I don't have to pass it as a constructor parameter.

My problem with the above code is that this resolves to Property[Any] instead of Property[T], so the generic is lost somewhere along the way. How can I fix this function to preserve the generic?


Solution

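  • Implementation notes:

    1. The Property class must be defined at module level so that it can be used in the return type annotation of create_property.
    2. Inside create_property, subclass Property and bind event_bus through a closure in the overridden __set__.
    3. Assign the factory returned by create_property to a separate variable, prop; otherwise the type hints are lost.
    4. Create the descriptor object x by calling prop; its type argument is inferred from the validator's parameter type.
    5. The class Property[T] / def create_property[T] syntax (PEP 695) requires Python 3.12 or later.
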
    # pyright: strict
    from typing import Any, Callable
    
    class EventBus:
        def publish(self):
            ...
    
    # Generic descriptor defined at module level so create_property can name it in its return type.
    class Property[T]:
        def __init__(self, validator: Callable[[T], bool]):
            self._validator = validator
    
        def __set_name__(self, obj: Any, name: str):
            self.name = name
    
        def __get__(self, obj: Any, objtype: Any) -> T:
            return obj.__dict__.get(self.name)
    
        def __set__(self, obj: Any, value: T):
            if not self._validator(value):
                raise ValueError("Invalid value")
            obj.__dict__[self.name] = value
    
    
    def create_property[T](event_bus: EventBus) -> Callable[[Callable[[T], bool]], Property[T]]:
        # Subclass on each call so that __set__ can close over event_bus.
        class _Prop(Property[T]):
            def __set__(self, obj: Any, value: T):
                super().__set__(obj, value)
                event_bus.publish()
    
        # The class itself is the returned callable: _Prop(validator) -> Property[T].
        return _Prop
    
    prop = create_property(event_bus=EventBus())
    
    def validator(value: int) -> bool:
        return True
    
    # T is inferred from validator's parameter type, so x is Property[int].
    x = prop(validator)
    

    Verification

    $ pyright
    0 errors, 0 warnings, 0 informations
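The module-level x above only exercises the type checker; to see the subclass-and-bind approach at runtime, the descriptor has to live on a class. The sketch below is a minimal illustration under a few assumptions: it uses TypeVar/Generic instead of the Python 3.12 class Property[T] syntax so it also runs on older versions, RecordingEventBus, User and is_adult are hypothetical names introduced only for this example, and __get__ additionally returns the descriptor itself on class-level access (the answer's version would fail there because obj is None).

```python
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")


class RecordingEventBus:
    """Hypothetical EventBus stand-in that counts how often publish() runs."""

    def __init__(self) -> None:
        self.published = 0

    def publish(self) -> None:
        self.published += 1


class Property(Generic[T]):
    """Module-level generic descriptor, mirroring the answer's Property."""

    def __init__(self, validator: Callable[[T], bool]) -> None:
        self._validator = validator

    def __set_name__(self, owner: Any, name: str) -> None:
        self.name = name

    def __get__(self, obj: Any, objtype: Any) -> T:
        if obj is None:
            # Class-level access (User.age): hand back the descriptor itself.
            return self  # type: ignore[return-value]
        return obj.__dict__.get(self.name)

    def __set__(self, obj: Any, value: T) -> None:
        if not self._validator(value):
            raise ValueError("Invalid value")
        obj.__dict__[self.name] = value


def create_property(
    event_bus: RecordingEventBus,
) -> Callable[[Callable[[T], bool]], Property[T]]:
    # Subclass inside the factory so the overridden __set__ closes over event_bus.
    class _Prop(Property[T]):
        def __set__(self, obj: Any, value: T) -> None:
            super().__set__(obj, value)
            event_bus.publish()

    return _Prop


def is_adult(age: int) -> bool:
    return age >= 18


bus = RecordingEventBus()
prop = create_property(bus)


class User:
    # A type checker should infer Property[int] from is_adult's parameter type.
    age = prop(is_adult)


u = User()
u.age = 30
print(u.age, bus.published)  # 30 1

try:
    u.age = 5  # rejected by the validator, so nothing is stored or published
except ValueError as exc:
    print(exc, bus.published)  # Invalid value 1
```

Because publish() is called only after the base class's __set__ returns, a value rejected by the validator never reaches the event bus.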