Subclassing in python: restricting the signature of child class functions


Let's say I have a base class like


class Foo:
    def __init__(self):
        return

    def foo(self, a: str | int) -> float:
        raise NotImplementedError()

which gets inherited by


class FooChild(Foo):
    def __init__(self):
        return

    def foo(self, a: str) -> float:
        return 0.5

Now mypy complains about a violation of the Liskov substitution principle (which I may have studied, but have surely forgotten), and I understand I'm breaking a fundamental OOP rule here. At the end of the day, I just want to define a parent class whose methods may accept different types, leaving the nitty-gritty details to the child classes.
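To make the complaint concrete, here is a minimal sketch of what can go wrong at runtime when a child narrows the parameter type. It assumes Python 3.7+ (the __future__ import lets the str | int annotation parse before 3.10). The helpers use_any_foo and substitution_fails are hypothetical, and FooChild.foo is changed to call len(a) so that passing an int actually fails.

```python
from __future__ import annotations


class Foo:
    def foo(self, a: str | int) -> float:
        raise NotImplementedError()


class FooChild(Foo):
    # Narrower parameter type than Foo.foo: this override is what mypy flags.
    # The ignore comment silences that error so the sketch itself type-checks.
    def foo(self, a: str) -> float:  # type: ignore[override]
        # Assumption: a body that only makes sense for str.
        return float(len(a))


def use_any_foo(obj: Foo) -> float:
    # Hypothetical caller written against Foo: passing an int is legal here.
    return obj.foo(42)


def substitution_fails() -> bool:
    # Hypothetical helper: True if a FooChild cannot stand in for a Foo.
    try:
        use_any_foo(FooChild())
    except TypeError:  # len(42) raises TypeError
        return True
    return False


print(substitution_fails())  # True
```

Every call that type-checks against Foo must also work on a FooChild. Because FooChild.foo rejects int, it cannot, and that is the violation mypy reports.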

Ideally, foo in the parent would accept a large union of types, and each inheriting child class would then narrow it to the type it actually handles. I'd like to keep this as elegant (i.e. short) as possible.

Do you know any workaround?

I tried using the @overload decorator with multiple signatures, like:


from typing import overload


class Foo:
    def __init__(self):
        return

    @overload
    def fun(self, a: str) -> float:
        ...

    @overload
    def fun(self, a: int) -> float:
        ...

    def fun(self, a: str | int) -> float:
        raise NotImplementedError()

and defining an ad-hoc TypeVar such as YoYo = TypeVar("YoYo", str, int), but mypy always complains, each time for a different reason.


Solution

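  • Use a type variable. Make the base class generic and let each subclass pick the concrete type. FooChild then inherits from Foo[str], whose foo already takes a str, so overriding it with a: str narrows nothing and there is no Liskov violation for mypy to report: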

    import typing
    
    T = typing.TypeVar("T")
    
    
    class Foo(typing.Generic[T]):
        def foo(self, a: T) -> float:
            raise NotImplementedError()
    
    
    class FooChild(Foo[str]):
        def foo(self, a: str) -> float:
            return 0.5
    
    
    x = FooChild().foo('foo')
    

    You can also constrain the type variable if you want to restrict subclasses to specific types:

    T = typing.TypeVar("T", float, str)


    For example, with T constrained to str and int:

    import typing
    
    T = typing.TypeVar("T", str, int)
    
    
    class Foo(typing.Generic[T]):
        def foo(self, a: T) -> float:
            raise NotImplementedError()
    
    
    class Bar(Foo[str]):
        def foo(self, a: str) -> float:
            return 0.5
    
    class Baz(Foo[list[int]]):
        def foo(self, a: list[int]) -> float:
            return 0.1
    

    mypy then reports an error on the definition of Baz:

    jarrivillaga$ mypy test_mypy.py 
    test_mypy.py:15: error: Value of type variable "T" of "Foo" cannot be "list"  [type-var]
    Found 1 error in 1 file (checked 1 source file)
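A variation on the answer, not part of it: if the parent should accept "a large union of types" and each child picks one member, a bound type variable expresses that directly. The sketch assumes typing.TypeVar's bound= argument with typing.Union[str, int]; StrFoo, IntFoo and call_foo are hypothetical names. Unlike the constrained version, a bound admits any subtype of str | int, including the union itself, while something like Foo[list[int]] should still be rejected by mypy (with a differently worded error than the one above).

```python
import typing

# Bound instead of constraints: T may be str, int, str | int,
# or any subtype of those.
T = typing.TypeVar("T", bound=typing.Union[str, int])


class Foo(typing.Generic[T]):
    def foo(self, a: T) -> float:
        raise NotImplementedError()


class StrFoo(Foo[str]):
    # Hypothetical child that handles only str.
    def foo(self, a: str) -> float:
        return float(len(a))


class IntFoo(Foo[int]):
    # Hypothetical child that handles only int.
    def foo(self, a: int) -> float:
        return a / 2


def call_foo(obj: Foo[T], a: T) -> float:
    # The argument type is tied to the subclass's type parameter,
    # so mypy should reject a call like call_foo(StrFoo(), 1).
    return obj.foo(a)


print(call_foo(StrFoo(), "abc"))  # 3.0
print(call_foo(IntFoo(), 5))  # 2.5
```

Because Foo[T] is invariant in T, a caller holding a Foo[str] can only pass a str, which is exactly the guarantee the narrowed override in the question was missing.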