Python can't infer static type of subclass when using a handler factory


I would like to use a general HandlerFactory class like the one described here (see Solution 2: Metaprogramming).

Let me use an example:

Suppose we have the following classes:

class Person:
    name: str

@PersonHandlerFactory.register
class Mark(Person):
    name = "Mark"
    job = "scientist"

@PersonHandlerFactory.register
class Charles(Person):
    name = "Charles"
    hobby = "football"

You may have noticed that the subclasses are decorated. The decorator registers each class in the following PersonHandlerFactory class, which returns the class registered under a given person name:

from typing import Dict, Type

class PersonHandlerFactory:
    handlers: Dict[str, Type[Person]] = {}

    @classmethod
    def register(cls, handler_cls: Type[Person]):
        cls.handlers[handler_cls.name] = handler_cls
        return handler_cls

    @classmethod
    def get(cls, name: str):
        return cls.handlers[name]

As you can see, I used the type Type[Person], because I want this method to be used for any subclass of Person.

But somehow the type checker treats an instance of any decorated subclass as having the static type Person:

mark = Mark()  # Static type of 'mark' is 'Person' :S
print(mark.job)  # The type checker can't resolve 'job' on 'Person'

I don't want to replace Type[Person] with Mark | Charles, because the class PersonHandlerFactory should not know about the subclasses of Person.


Solution

  • Use a bound TypeVar. This ties the decorator's return type to the exact class passed in, so the correct subclass is inferred.

    from typing import Dict, Type, TypeVar
    
    
    class Person:
        name: str
    
    
    T = TypeVar("T", bound=Person)
    
    
    class PersonHandlerFactory:
        handlers: Dict[str, Type[Person]] = {}
    
        @classmethod
        def register(cls, handler_cls: Type[T]) -> Type[T]:
            cls.handlers[handler_cls.name] = handler_cls
            return handler_cls
    
        @classmethod
        def get(cls, name: str):
            return cls.handlers[name]
    
    
    @PersonHandlerFactory.register
    class Mark(Person):
        name = "Mark"
        job = "scientist"
    
    
    mark = Mark()  # Static type of 'mark' is 'Mark' :)
    

    The class returned by PersonHandlerFactory.get, however, will always be inferred as Type[Person], because the handlers dictionary only records its values as Type[Person].
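
If a narrower type is needed from a lookup, one possible workaround is to let the caller state which class it expects and verify that at runtime. The sketch below is only an illustration under that assumption: get_as and its expected parameter are hypothetical names that are not part of the code above, and the TypeError on a mismatch is just one reasonable choice.

```python
from typing import Dict, Type, TypeVar


class Person:
    name: str


T = TypeVar("T", bound=Person)


class PersonHandlerFactory:
    handlers: Dict[str, Type[Person]] = {}

    @classmethod
    def register(cls, handler_cls: Type[T]) -> Type[T]:
        cls.handlers[handler_cls.name] = handler_cls
        return handler_cls

    @classmethod
    def get(cls, name: str) -> Type[Person]:
        return cls.handlers[name]

    # Hypothetical helper (an assumption, not in the original answer):
    # the caller passes the class it expects, so the return type can be
    # expressed with the same bound TypeVar used by register.
    @classmethod
    def get_as(cls, name: str, expected: Type[T]) -> Type[T]:
        handler_cls = cls.handlers[name]
        # Runtime check that the registered class really is the expected one
        # (or a subclass of it) before handing it back under the narrower type.
        if not issubclass(handler_cls, expected):
            raise TypeError(
                f"handler {name!r} is {handler_cls.__name__}, "
                f"not a subclass of {expected.__name__}"
            )
        return handler_cls


@PersonHandlerFactory.register
class Mark(Person):
    name = "Mark"
    job = "scientist"


mark_cls = PersonHandlerFactory.get_as("Mark", Mark)
print(mark_cls().job)
```

The trade-off is that the call site has to name the subclass again, so this only helps where the caller already knows which class it expects; PersonHandlerFactory itself still knows nothing about the subclasses of Person.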