Typing: How to consider class arguments wrapped with partial?

I have a class MyClass that expects as an argument a class, parent_cls, which fulfills the interface ParentInterface.

ChildA implements (extends) ParentInterface. Since MyClass instantiates parent_cls with only the arguments a and b, the additional argument c of ChildA is bound beforehand, outside of MyClass, using functools.partial.
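As a minimal illustration of that binding (class bodies simplified from the question's code), partial fixes c by keyword, so the resulting object can be called with only a and b, exactly as MyClass does internally. Note that the object returned by partial is a callable, not a class.

```python
from functools import partial


class ParentInterface:
    def __init__(self, a: int, b: int):
        self.a = a
        self.b = b


class ChildA(ParentInterface):
    def __init__(self, a: int, b: int, c: str):
        super().__init__(a, b)
        self.c = c


# Bind the extra keyword argument c ahead of time.
child_a_cls = partial(ChildA, c='some string')

# The caller supplies only a and b, just like MyClass does.
obj = child_a_cls(3, 4)
print(type(obj).__name__, obj.a, obj.b, obj.c)  # ChildA 3 4 some string

# A partial object is not a class, so it does not really match Type[ParentInterface].
print(isinstance(child_a_cls, type))  # False
```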

In principle this runs fine in Python. However, I get a type warning from PyCharm.

Any ideas how to fix that type warning?

from abc import ABC, abstractmethod
from functools import partial
from typing import Type, Callable, cast

class ParentInterface(ABC):
    def __init__(self, a: int, b: int):
        self.a = a
        self.b = b

    @abstractmethod
    def do_something(self):
        ...

class ChildA(ParentInterface):
    def __init__(self, a: int, b: int, c: str):
        super().__init__(a, b)
        self.c = c

    def do_something(self):
        print('I am ChildA')

# update 1
class ChildB(ParentInterface):
    def __init__(self, a: int, b: int):
        super().__init__(a, b)

    def do_something(self):
        print('I am ChildB')

class MyClass:
    def __init__(self, parent_cls: Type[ParentInterface]):
        self.parent = parent_cls(3, 4)

# alternative
# class MyClass:
#     def __init__(self, parent_cls: Callable[[int, int], ParentInterface]):
#         self.parent = parent_cls(3, 4)

def typed_partial(cls, *args, **kwargs):
    return cast(Type[cls], partial(cls, *args, **kwargs))

# original code
# child_a_cls = partial(ChildA, c='some string')
# solution
child_a_cls = typed_partial(ChildA, c='some string')

my_class_with_childa = MyClass(parent_cls=child_a_cls)
my_class_with_childb = MyClass(parent_cls=ChildB)
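A sketch of the commented-out alternative in the question, without any cast: if MyClass only ever calls parent_cls(3, 4), then Callable[[int, int], ParentInterface] describes exactly what it needs, and both a class like ChildB and a partial object satisfy it at runtime. How precisely a given checker models partial objects varies between tools and versions, so treat the static-checking side as an assumption to verify in your setup. The alias ParentFactory is a hypothetical name, and do_something returns a string instead of printing so the result is easy to check.

```python
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable


class ParentInterface(ABC):
    def __init__(self, a: int, b: int):
        self.a = a
        self.b = b

    @abstractmethod
    def do_something(self) -> str:
        ...


class ChildA(ParentInterface):
    def __init__(self, a: int, b: int, c: str):
        super().__init__(a, b)
        self.c = c

    def do_something(self) -> str:
        return 'I am ChildA'


class ChildB(ParentInterface):
    def do_something(self) -> str:
        return 'I am ChildB'


# Hypothetical alias: any callable taking (a, b) and returning a ParentInterface,
# e.g. a class, a partial object, a lambda or a factory function.
ParentFactory = Callable[[int, int], ParentInterface]


class MyClass:
    def __init__(self, parent_cls: ParentFactory):
        self.parent = parent_cls(3, 4)


my_class_with_childa = MyClass(parent_cls=partial(ChildA, c='some string'))
my_class_with_childb = MyClass(parent_cls=ChildB)
print(my_class_with_childa.parent.do_something())  # I am ChildA
print(my_class_with_childb.parent.do_something())  # I am ChildB
```

The trade-off is that MyClass no longer promises that parent_cls is a real class (for example, for use with issubclass). If MyClass only instantiates it, the Callable annotation is arguably the more accurate type.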


  • You can use typing.cast to force the type checker to assume the indicated type:

    child_a_cls = cast(
        Type[ChildA],
        partial(ChildA, c='some string')
    )

    At runtime, cast is a no-op that returns the value unchanged (though it is still a function call). See also PEP 484.

    If you do this often, you could also move it to a separate function:

    def typed_partial(cls, *args, **kwargs):
        return cast(Type[cls], partial(cls, *args, **kwargs))
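One possible refinement of that helper, sketched under the assumption that you also want the helper itself to type-check: Type[cls] puts a runtime variable inside a type expression, which checkers such as mypy generally reject, whereas a TypeVar lets the checker carry the concrete class through. Runtime behavior is unchanged, since cast just returns the partial object. The function build is a hypothetical stand-in for MyClass.__init__, and the classes are simplified.

```python
from functools import partial
from typing import Any, Type, TypeVar, cast

T = TypeVar('T')


def typed_partial(cls: Type[T], *args: Any, **kwargs: Any) -> Type[T]:
    # The cast tells the checker to treat the partial object as the class itself.
    # This is deliberately imprecise: the result is only safe to call with the
    # arguments that were not bound here, and it is not a real class at runtime.
    return cast(Type[T], partial(cls, *args, **kwargs))


class ParentInterface:
    # Simplified stand-ins for the question's classes.
    def __init__(self, a: int, b: int):
        self.a, self.b = a, b


class ChildA(ParentInterface):
    def __init__(self, a: int, b: int, c: str):
        super().__init__(a, b)
        self.c = c


def build(parent_cls: Type[ParentInterface]) -> ParentInterface:
    # Hypothetical helper mirroring MyClass.__init__: only a and b are passed.
    return parent_cls(3, 4)


child_a_cls = typed_partial(ChildA, c='some string')  # checker sees Type[ChildA]
parent = build(child_a_cls)
print(type(parent).__name__)  # ChildA
print(isinstance(child_a_cls, partial))  # True: cast changed nothing at runtime
```

Because the checker now believes child_a_cls is ChildA itself, a direct call such as child_a_cls(3, 4) outside of MyClass would likely be flagged for the missing c, even though it works at runtime. The Callable annotation from the question's commented-out alternative avoids that mismatch.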