
Python method overriding - more specific arguments in derived than base class


Let's say I want to create an abstract base class called Document. I want the type checker to guarantee that all its subclasses implement a class method called from_paragraphs, which constructs a document from a sequence of Paragraph objects. However, a LegalDocument should only be constructable from LegalParagraph objects, and an AcademicDocument only from AcademicParagraph objects.

My instinct is to do it like so:

from abc import ABC, abstractmethod
from typing import Sequence


class Document(ABC):
    @classmethod
    @abstractmethod
    def from_paragraphs(cls, paragraphs: Sequence["Paragraph"]):
        pass


class LegalDocument(Document):
    @classmethod
    def from_paragraphs(cls, paragraphs: Sequence["LegalParagraph"]):
        return  # some logic here...


class AcademicDocument(Document):
    @classmethod
    def from_paragraphs(cls, paragraphs: Sequence["AcademicParagraph"]):
        return  # some logic here...


class Paragraph:
    text: str


class LegalParagraph(Paragraph):
    pass


class AcademicParagraph(Paragraph):
    pass

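However, Pyright complains about this: the from_paragraphs overrides in the derived classes narrow the parameter type from Sequence[Paragraph] to a sequence of a more specific paragraph class, which violates the Liskov substitution principle. How do I make sure that each derived class implements from_paragraphs for some kind of Paragraph?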
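
To illustrate why the checker objects, here is a minimal sketch of the kind of failure it guards against. The `clause` attribute, the `build` helper, and the constructors are hypothetical and added only for this demonstration; they are not part of my real code.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence, Type


class Paragraph:
    def __init__(self, text: str) -> None:
        self.text = text


class LegalParagraph(Paragraph):
    def __init__(self, text: str, clause: str) -> None:
        super().__init__(text)
        self.clause = clause  # hypothetical attribute only legal paragraphs have


class Document(ABC):
    @classmethod
    @abstractmethod
    def from_paragraphs(cls, paragraphs: Sequence[Paragraph]) -> "Document":
        ...


class LegalDocument(Document):
    def __init__(self, clauses: List[str]) -> None:
        self.clauses = clauses

    @classmethod
    def from_paragraphs(cls, paragraphs: Sequence[LegalParagraph]) -> "LegalDocument":
        # Narrowed parameter type: this is the override a type checker flags.
        return cls([p.clause for p in paragraphs])


def build(doc_cls: Type[Document], paragraphs: Sequence[Paragraph]) -> Document:
    # Valid against the base signature, so any Document subclass may be passed in.
    return doc_cls.from_paragraphs(paragraphs)


def demo_failure() -> str:
    try:
        build(LegalDocument, [Paragraph("plain text")])
    except AttributeError as exc:
        return type(exc).__name__
    return "no error"


print(demo_failure())  # AttributeError
```

The call to `build` looks fine when checked against `Document`, yet it fails at runtime because a plain `Paragraph` reaches code that expects a `LegalParagraph`.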


Solution

  • Turns out this can be solved using generics:

    from abc import ABC, abstractmethod
    from typing import Generic, Sequence, TypeVar
    
    ParagraphType = TypeVar("ParagraphType", bound="Paragraph")
    
    
    class Document(ABC, Generic[ParagraphType]):
        @classmethod
        @abstractmethod
        def from_paragraphs(cls, paragraphs: Sequence[ParagraphType]):
            pass
    
    
    class LegalDocument(Document["LegalParagraph"]):
        @classmethod
        def from_paragraphs(cls, paragraphs):
            return  # some logic here...
    
    
    class AcademicDocument(Document["AcademicParagraph"]):
        @classmethod
        def from_paragraphs(cls, paragraphs):
            return  # some logic here...
    
    
    class Paragraph:
        text: str
    
    
    class LegalParagraph(Paragraph):
        pass
    
    
    class AcademicParagraph(Paragraph):
        pass
    

    Saying bound="Paragraph" guarantees that ParagraphType stands for Paragraph or one of its subclasses. Each derived class fixes a single paragraph type when it subclasses Document[...], so it is only expected to implement from_paragraphs for the type it chose, not for all paragraph types. Pyright also infers the type of the paragraphs argument of LegalDocument.from_paragraphs from the base class, saving me some work :)
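
For completeness, a runnable sketch of the generic approach with a usage example. The `__init__` that stores the paragraphs and the explicit annotations on the overrides are assumptions made for illustration; the explicit annotations are optional with Pyright but keep the intent visible for checkers that do not infer parameter types from the base class.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, Sequence, TypeVar


class Paragraph:
    def __init__(self, text: str) -> None:
        self.text = text


class LegalParagraph(Paragraph):
    pass


class AcademicParagraph(Paragraph):
    pass


ParagraphType = TypeVar("ParagraphType", bound=Paragraph)


class Document(ABC, Generic[ParagraphType]):
    def __init__(self, paragraphs: Sequence[ParagraphType]) -> None:
        # Hypothetical storage, just so the example has something to inspect.
        self.paragraphs: List[ParagraphType] = list(paragraphs)

    @classmethod
    @abstractmethod
    def from_paragraphs(
        cls, paragraphs: Sequence[ParagraphType]
    ) -> "Document[ParagraphType]":
        ...


class LegalDocument(Document[LegalParagraph]):
    @classmethod
    def from_paragraphs(cls, paragraphs: Sequence[LegalParagraph]) -> "LegalDocument":
        return cls(paragraphs)


class AcademicDocument(Document[AcademicParagraph]):
    @classmethod
    def from_paragraphs(
        cls, paragraphs: Sequence[AcademicParagraph]
    ) -> "AcademicDocument":
        return cls(paragraphs)


legal = LegalDocument.from_paragraphs(
    [LegalParagraph("Clause 1"), LegalParagraph("Clause 2")]
)
print([p.text for p in legal.paragraphs])  # ['Clause 1', 'Clause 2']

# A static checker should reject the next line, because AcademicParagraph
# is not a LegalParagraph. Nothing enforces this at runtime.
# LegalDocument.from_paragraphs([AcademicParagraph("Abstract")])
```

The overrides are now compatible with the base class because `ParagraphType` is already bound to a concrete paragraph class in `Document[LegalParagraph]` and `Document[AcademicParagraph]`, so no parameter type is being narrowed.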