
How to make Python typing recognize subclasses as valid types when it expects their parent class?


Here is a minimal example of what I need to do:

from typing import Callable, Any


class Data:
    pass


class SpecificData(Data):
    pass


class Event:
    pass


class SpecificEvent(Event):
    pass


def detect_specific_event(data: SpecificData, other_info: str) -> SpecificEvent:
    return SpecificEvent()


def run_detection(callback: Callable[[Data, Any], Event]) -> None:
    return


run_detection(detect_specific_event)

Now I get a warning:

Expected type '(Data, Any) -> Event', got '(data: SpecificData, other_info: str) -> SpecificEvent' instead 

To me it seems like this warning doesn't make sense, as SpecificData and SpecificEvent are subtypes of Data and Event respectively, so everything should be fine. Is there a way to make this work as I expect? My idea is to be able to then have something like:

class OtherSpecificData(Data):
    pass


class OtherSpecificEvent(Event):
    pass


def detect_other_event(data: OtherSpecificData, other_info: str) -> OtherSpecificEvent:
    return OtherSpecificEvent()

run_detection(detect_other_event)

so the run_detection function is as general as possible. Right now this gives the same warning as above.


Solution

  • Parameter types are subtyped in the opposite direction from return types (parameters are contravariant, return types are covariant).

    • A return value is passed from the callee to the caller.
    • A parameter value is passed from the caller to the callee.

    In both cases, the value being assigned must be at least as specific as the type the receiving variable expects. For example:

    general: Data = SpecificData()   # okay: every SpecificData is a Data
    specific: SpecificData = Data()  # not okay: a Data need not be a SpecificData
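
To see why the checker is right to reject the original call, here is a minimal runnable sketch. The `threshold` attribute, the body of `run_detection`, and the helper `call_fails` are hypothetical additions for illustration only; in the question the classes are empty and `run_detection` does nothing.

```python
from typing import Any, Callable


class Data:
    pass


class SpecificData(Data):
    def __init__(self) -> None:
        # Hypothetical attribute that only the subclass provides.
        self.threshold = 0.5


class Event:
    pass


class SpecificEvent(Event):
    pass


def detect_specific_event(data: SpecificData, other_info: str) -> SpecificEvent:
    # Relies on an attribute that a plain Data instance does not have.
    _ = data.threshold
    return SpecificEvent()


def run_detection(callback: Callable[[Data, Any], Event]) -> Event:
    # The annotation promises that the callback copes with any Data,
    # so handing it a plain Data is legal from this function's point of view.
    return callback(Data(), "info")


def call_fails() -> bool:
    """Return True if the call flagged by the type checker breaks at runtime."""
    try:
        run_detection(detect_specific_event)
    except AttributeError:
        return True
    return False
```

Returning a `SpecificEvent` where an `Event` is expected causes no trouble here; only the narrower parameter type does, because `run_detection` is free to pass any `Data`.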
    

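    So the parameter types have to go the other way round: the callback may accept a more general parameter type than the Callable annotation names, while it may return a more specific type: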

    from typing import Callable, Any
    
    
    class Data:
        pass
    
    
    class SpecificData(Data):
        pass
    
    
    class Event:
        pass
    
    
    class SpecificEvent(Event):
        pass
    
    
    def detect_specific_event(data: Data, other_info: str) -> SpecificEvent:
        return SpecificEvent()
    
    
    def run_detection(callback: Callable[[SpecificData, Any], Event]) -> None:
        return
    
    
    run_detection(detect_specific_event)
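
If `run_detection` should keep accepting detectors written for particular subclasses, as the question intends, one possible alternative (not part of the answer above) is to make it generic with a `TypeVar` bound to `Data`. The sketch below assumes that `run_detection` also receives the data it will hand to the callback; that extra `data` parameter, the `other_info: str` parameter, and the names `D` and `E` are assumptions made for illustration.

```python
from typing import Callable, TypeVar


class Data:
    pass


class SpecificData(Data):
    pass


class OtherSpecificData(Data):
    pass


class Event:
    pass


class SpecificEvent(Event):
    pass


class OtherSpecificEvent(Event):
    pass


# Hypothetical type variables: D ties the callback's parameter type to the
# data that is actually passed in, E carries the concrete event type through.
D = TypeVar("D", bound=Data)
E = TypeVar("E", bound=Event)


def detect_specific_event(data: SpecificData, other_info: str) -> SpecificEvent:
    return SpecificEvent()


def detect_other_event(data: OtherSpecificData, other_info: str) -> OtherSpecificEvent:
    return OtherSpecificEvent()


def run_detection(callback: Callable[[D, str], E], data: D, other_info: str) -> E:
    # The callback only ever receives the kind of data it was written for.
    return callback(data, other_info)


first = run_detection(detect_specific_event, SpecificData(), "info")
second = run_detection(detect_other_event, OtherSpecificData(), "info")
```

With this signature a type checker should accept both calls, and it should reject a mismatched pair such as `detect_specific_event` together with an `OtherSpecificData` instance, since no single `D` satisfies both arguments.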