Tags: python, python-typing, mypy

Inheritance and polymorphism in Python not working with mypy


I'm trying to do some standard polymorphism with mypy, which I've never used before, and so far it isn't intuitive.

Base class:

from typing import Any, Dict, List, Protocol


class ContentPullOptions:
    pass


class Tool(Protocol):
    async def pull_content(self, opts: ContentPullOptions) -> str | Dict[str, Any]: ...

Subclass:

class GoogleSearchOptions(ContentPullOptions):
    query: str
    sites: List[str]


class GoogleSearchTool(Tool):
    async def pull_content(
        self,
        opts: GoogleSearchOptions,
    ) -> str | Dict[str, Any]:
        raise NotImplementedError  # implementation omitted

mypy fails with:

error: Argument 1 of "pull_content" is incompatible with supertype "Tool"; supertype defines the argument type as "ContentPullOptions" 

What is the cleanest, most maintainable way to do basic inheritance like this with mypy type checking?

I tried custom types, casting, etc., but everything felt messy and unclear.

Solution

This may still violate the Liskov substitution principle:

from typing import Any, Dict, List, Protocol, TypeVar


class ContentPullOptions:
    pass


# Contravariant, since T_contra only appears in argument position.
# The bound class must already be defined, or this line raises NameError.
T_contra = TypeVar('T_contra', bound=ContentPullOptions, contravariant=True)


class Tool(Protocol[T_contra]):
    async def pull_content(self, opts: T_contra) -> str | Dict[str, Any]: ...


class GoogleSearchOptions(ContentPullOptions):
    query: str
    sites: List[str]


class GoogleSearchTool:
    async def pull_content(
        self,
        opts: GoogleSearchOptions,
    ) -> str | Dict[str, Any]:
        # Implementation here
        raise NotImplementedError
