python generics type-inference mypy python-typing

Type casting of Generics


I'm using mypy and ran into unexpected behaviour. Mypy infers the wrong expected type for the callback argument:

from typing import Generic, TypeVar, Callable, reveal_type

S1 = TypeVar('S1')
F1 = TypeVar('F1')
I = TypeVar('I')


class Node(Generic[I, S1, F1]):
    def __init__(self, callback: Callable[[I], S1 | F1]):
        self.callback: Callable[[I], S1 | F1] = callback


class Succ1:
    ...


class Fail1:
    ...


def func1(_: str) -> Succ1 | Fail1:
    return Succ1()


n1 = Node(func1)
res1 = n1.callback("str")
reveal_type(n1)

% mypy isolated_example2.py
isolated_example2.py:25: error: Need type annotation for "n1"  [var-annotated]
isolated_example2.py:25: error: Argument 1 to "Node" has incompatible type "Callable[[str], Succ1 | Fail1]"; expected "Callable[[str], Never]"  [arg-type]
isolated_example2.py:27: note: Revealed type is "isolated_example2.Node[builtins.str, Any, Any]"
Found 2 errors in 1 file (checked 1 source file)

I did not expect the type Callable[[str], Never] and see no reason for mypy to infer it. What could be the problem?

mypy==1.9.0
Python 3.12.3

This is part of a bigger problem, but I'm trying to split it into isolated chunks to understand the process better.

UPD: This looks like a problem with unions of TypeVars. Here is a minimal example that reproduces the error:

from typing import TypeVar

S1 = TypeVar('S1')
F1 = TypeVar('F1')


def func(arg: S1 | F1):
    return arg


func(None)

Solution

  • A union is commutative: Succ1 | Fail1 and Fail1 | Succ1 are the same type. A type checker cannot just split a union into two ordered parts on its own. What if you have def foo(self) -> S1 in the class?

    class Node(Generic[I, S1, F1]):
        def __init__(self, callback: Callable[[I], S1 | F1]):
            self.callback: Callable[[I], S1 | F1] = callback
        
        def foo(self) -> S1:
            # Huh? Succ1? Fail1? Succ1 | Fail1? object? Something else?
            raise NotImplementedError
    

    Should S1 resolve to Succ1 or Fail1 here? Why?

    So you need some way to tell the type checker which members of the union belong to which type variable. You could use bound type variables for that:

    from typing import Generic, TypeVar, Callable, Literal, Protocol, reveal_type
    
    class SuccBase(Protocol):
        success: Literal[True]
    class FailBase(Protocol):
        success: Literal[False]
    
    S1 = TypeVar('S1', bound=SuccBase)
    F1 = TypeVar('F1', bound=FailBase)
    I = TypeVar('I')
    
    class Node(Generic[I, S1, F1]):
        def __init__(self, callback: Callable[[I], S1 | F1]):
            self.callback: Callable[[I], S1 | F1] = callback
    
    
    # Pyright accepts these classes without explicit protocol inheritance.
    # mypy needs the protocol to be inherited explicitly (possibly a mypy bug).
    class Succ1(SuccBase):
        success: Literal[True] = True
    class Fail1(FailBase):
        success: Literal[False] = False
    class Fail2(FailBase):
        success: Literal[False] = False
    
    
    def func1(_: str) -> Succ1 | Fail1 | Fail2:
        return Succ1()
    
    
    n1 = Node(func1)
    res1 = n1.callback("str")
    reveal_type(n1)
    

    mypy playground, pyright playground

    I'm using a "discriminated union" kind of solution with Protocol above, but you can also require explicit inheritance (an ABC, or just a plain FailBase class with no attributes), use different attributes that make sense in your context, or tweak something else. The key point is that you need to clearly indicate what should go into S1 and what into F1.
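
For illustration, here is a minimal sketch of the explicit-inheritance variant, assuming plain marker base classes instead of a Protocol. The empty SuccBase and FailBase classes are hypothetical stand-ins, and Union[...] is written out instead of | only so the snippet also runs on older Python versions. I'd expect mypy to infer Node[str, Succ1, Fail1] here, because each bound admits exactly one member of the union.

```python
from typing import Callable, Generic, TypeVar, Union


class SuccBase:
    """Marker base: anything inheriting from it counts as a success."""


class FailBase:
    """Marker base: anything inheriting from it counts as a failure."""


S1 = TypeVar("S1", bound=SuccBase)
F1 = TypeVar("F1", bound=FailBase)
I = TypeVar("I")


class Node(Generic[I, S1, F1]):
    # Union[S1, F1] is equivalent to S1 | F1.
    def __init__(self, callback: Callable[[I], Union[S1, F1]]) -> None:
        self.callback = callback


class Succ1(SuccBase): ...
class Fail1(FailBase): ...


def func1(_: str) -> Union[Succ1, Fail1]:
    return Succ1()


# The bounds tell the checker that Succ1 can only go into S1
# and Fail1 can only go into F1.
n1 = Node(func1)
res1 = n1.callback("str")

# The same marker bases also work for narrowing at runtime.
if isinstance(res1, SuccBase):
    print("success")
else:
    print("failure")
```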

    If you are in full control of the API design, consider not doing this and returning a proper Result-like type instead (e.g. https://github.com/rustedpy/result, or roll your own, it's trivial). Among other things, this makes never-failing functions more explicit (-> Result[Succ1, Never]). Or just get rid of Fail1 entirely: do not return exceptions, raise them. Python is built around exceptions, and the performance penalty is quite low unless you're in a tight loop.
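
A rough sketch of what a hand-rolled Result-like type could look like. The names Ok, Err and Result are my own choice for illustration and are not the API of the rustedpy/result package. Because the wrapper class rather than the payload type says which side is which, a checker should have no ambiguity left to resolve.

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar, Union

T = TypeVar("T")  # success payload
E = TypeVar("E")  # failure payload
I = TypeVar("I")  # callback input


@dataclass(frozen=True)
class Ok(Generic[T]):
    value: T


@dataclass(frozen=True)
class Err(Generic[E]):
    error: E


# Hypothetical generic alias: Result[T, E] means Ok[T] or Err[E].
Result = Union[Ok[T], Err[E]]


class Node(Generic[I, T, E]):
    def __init__(self, callback: Callable[[I], Result[T, E]]) -> None:
        self.callback = callback


class Succ1: ...
class Fail1: ...


def func1(_: str) -> Result[Succ1, Fail1]:
    return Ok(Succ1())


n1 = Node(func1)  # should be inferable as Node[str, Succ1, Fail1]
res1 = n1.callback("str")

# Callers narrow on the wrapper class.
if isinstance(res1, Ok):
    print("success:", type(res1.value).__name__)
else:
    print("failure:", type(res1.error).__name__)
```

The cost is that every callback has to wrap its return value in Ok or Err, which is only practical when you own the callbacks.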

    Side note: you see Never there only because inference failed. It's a side effect that goes away if you annotate n1 explicitly.
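
To make the side note concrete, here is a sketch of the original example with n1 annotated explicitly. Union[...] is written out instead of | only so it runs on older Python versions. The second annotation, with the arguments swapped, is there to show that more than one split satisfies the signature, which is why I'd expect a checker to refuse to pick one by itself.

```python
from typing import Callable, Generic, TypeVar, Union

S1 = TypeVar("S1")
F1 = TypeVar("F1")
I = TypeVar("I")


class Node(Generic[I, S1, F1]):
    # Union[S1, F1] is equivalent to S1 | F1.
    def __init__(self, callback: Callable[[I], Union[S1, F1]]) -> None:
        self.callback = callback


class Succ1: ...
class Fail1: ...


def func1(_: str) -> Union[Succ1, Fail1]:
    return Succ1()


# With the annotation, nothing is left to infer for S1 and F1.
n1: Node[str, Succ1, Fail1] = Node(func1)
res1 = n1.callback("str")

# The swapped split describes the same callback type equally well,
# so without an annotation there is no single right answer.
n2: Node[str, Fail1, Succ1] = Node(func1)
```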