python typing mypy

Type subclass from inside the parent class


Suppose we have the following class:

from __future__ import annotations
from typing import NewType

class BaseSwallow:
    # Can't get the ref to `BaseSwallow` at runtime
    DerivedSwallow = NewType('DerivedSwallow', BaseSwallow) 

    def carry_with(self, swallow: DerivedSwallow):
        self.carry()
        swallow.carry()

    def carry(self):
        pass

class AfricanSwallow(BaseSwallow): pass

godOfSwallows = BaseSwallow()
africanSwallow = AfricanSwallow()

africanSwallow.carry_with(godOfSwallows)  # Should fail at static typing

I want to enforce that carry_with is only called with instances of classes derived from BaseSwallow (not with BaseSwallow itself), so I used NewType, following the documentation.

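However, NewType needs a reference to the base class object, and the name BaseSwallow is not bound yet while its own class body is executing. Thanks to from __future__ import annotations I can refer to BaseSwallow in annotations before it exists, but the NewType(...) call is an ordinary expression, not an annotation, so it still fails at runtime.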
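A minimal sketch reproducing this runtime failure, assuming a plain module without the __future__ import and using a string annotation in its place: the class body runs before the name BaseSwallow is bound, so the NewType(...) call raises NameError, whereas a string annotation in the same body is stored as-is and not resolved at definition time.

```python
from typing import NewType

caught = None  # will hold the exception raised while building the class

try:
    class BaseSwallow:
        # Executed while the class body runs; the name BaseSwallow is not bound yet
        DerivedSwallow = NewType('DerivedSwallow', BaseSwallow)
except NameError as exc:
    caught = exc


class BaseSwallow:
    # A string annotation is kept as a plain string, so naming the
    # not-yet-finished class here is fine
    def carry_with(self, swallow: 'BaseSwallow') -> None:
        pass


print(type(caught).__name__)
print(BaseSwallow.carry_with.__annotations__['swallow'])
```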

I'm aware that using an Abstract Base Class for BaseSwallow is the best thing to do here in most cases, but I can't do that for various reasons.

Any ideas?


Solution

  • I don't think there's a way to express "all subclasses of T, excluding T itself" with type annotations. If you have a fixed set of subclasses, you could use a Union type to list them all, but that's probably not what you want. I think Sam's answer is your best bet: simply annotate the parameter with the BaseSwallow base class instead of crafting a complicated type that rules out the base class itself.
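For the fixed-set case, a sketch of the Union approach might look like this. EuropeanSwallow is a hypothetical second subclass added only for illustration, and the carried list exists only to make the calls observable. The annotation is a string so the subclasses can be named before they are defined; a static checker such as mypy would then reject a plain BaseSwallow argument, while Python itself enforces nothing at runtime.

```python
from typing import List, Union

carried: List[str] = []  # records which classes carried, to make the example observable


class BaseSwallow:
    # Forward-reference annotation: a static checker only accepts the listed
    # subclasses; BaseSwallow itself is not part of the Union.
    def carry_with(self, swallow: 'Union[AfricanSwallow, EuropeanSwallow]') -> None:
        self.carry()
        swallow.carry()

    def carry(self) -> None:
        carried.append(type(self).__name__)


class AfricanSwallow(BaseSwallow):
    pass


class EuropeanSwallow(BaseSwallow):  # hypothetical second subclass
    pass


AfricanSwallow().carry_with(EuropeanSwallow())  # accepted by a static checker
# AfricanSwallow().carry_with(BaseSwallow())    # would be flagged: not in the Union
print(carried)  # ['AfricanSwallow', 'EuropeanSwallow']
```

The drawback is the one mentioned above: every new subclass has to be added to the Union by hand.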

    Also, I think you misunderstood the purpose of NewType. NewType creates a distinct type that the type checker treats as a subtype of the original one, so a plain value has to be explicitly converted before it is accepted. For example:

    from typing import NewType

    URL = NewType('URL', str)
    
    def download(url: URL): ...
    
    link_str = "https://..."  # inferred type is `str`
    link_url = URL(link_str)  # inferred type is `URL`
    download(link_str)  # type error
    download(link_url)  # correct
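A small runnable version of the example above shows that this conversion exists only for the type checker; the URL value is a placeholder and the download body is a hypothetical stand-in. At runtime URL(...) just returns its argument unchanged, so NewType('DerivedSwallow', BaseSwallow) would likewise describe values explicitly wrapped with DerivedSwallow(...), not instances of subclasses such as AfricanSwallow.

```python
from typing import NewType

URL = NewType('URL', str)


def download(url: URL) -> str:
    # Hypothetical placeholder body: just echo the URL back
    return url


link_str = "https://example.com"  # inferred type is `str`
link_url = URL(link_str)           # inferred type is `URL`, but still a plain str at runtime

print(type(link_url) is str)       # True: no wrapper object is created
print(link_url is link_str)        # True: URL(...) returns its argument unchanged
print(URL.__supertype__ is str)    # True: the underlying type is kept for introspection
print(download(link_url))
```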
    

    Edit: If you don't mind a little overhead, you can achieve this with an additional level of inheritance. Create a subclass of BaseSwallow (named Swallow for convenience) and have all the derived classes inherit from Swallow instead of BaseSwallow. You can then annotate the carry_with method with the Swallow type:

    class BaseSwallow:
        def carry_with(self, swallow: 'Swallow'):  # string as forward reference
            self.carry()
            swallow.carry()
    
        def carry(self):
            pass
    
    class Swallow(BaseSwallow): pass  # dummy class to serve as base
    
    class AfricanSwallow(Swallow): pass
    
    godOfSwallows = BaseSwallow()
    africanSwallow = AfricanSwallow()
    
    africanSwallow.carry_with(godOfSwallows)  # mypy warns about incompatible types
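mypy flags the last call because BaseSwallow is not a subclass of Swallow; Python itself performs no such check, so the call still executes. A quick runtime sketch, repeating the classes above in condensed form, makes both points visible:

```python
class BaseSwallow:
    def carry_with(self, swallow: 'Swallow') -> None:  # string as forward reference
        self.carry()
        swallow.carry()

    def carry(self) -> None:
        pass


class Swallow(BaseSwallow):  # dummy intermediate base for all "real" swallows
    pass


class AfricanSwallow(Swallow):
    pass


# The relationships the static checker relies on:
print(issubclass(AfricanSwallow, Swallow))  # True: accepted by carry_with
print(issubclass(BaseSwallow, Swallow))     # False: rejected by mypy

# No runtime enforcement: mypy reports this call, but Python runs it without error
AfricanSwallow().carry_with(BaseSwallow())
```

The extra Swallow layer costs one empty class, and in exchange the checker can tell "the base" apart from "any real swallow" without listing subclasses.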