python, python-typing, mypy

How to Bound TypeVar Correctly to Protocol?


I want to type annotate a simple sorting function that receives a list of any values that have either __lt__, __gt__, or both, but not mixed methods (i.e., all values should have the same comparison method) and returns a list containing the same elements it received in sorted order.

What I've done so far:

from typing import Any, Protocol

class SupportsLT(Protocol):
    def __lt__(self, other: Any, /) -> bool: ...

class SupportsGT(Protocol):
    def __gt__(self, other: Any, /) -> bool: ...

def quick_sort[T: SupportsLT | SupportsGT](seq: list[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    return [*quick_sort(smaller), pivot, *quick_sort(larger_or_equal)]

Running mypy --strict reports: error: Unsupported left operand type for < (some union)  [operator]. I think this is because nothing guarantees that all values implement the same protocol.

Another attempt is using constraints instead of bound:

def quick_sort[T: (SupportsLT, SupportsGT)](seq: list[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    return [*quick_sort(smaller), pivot, *quick_sort(larger_or_equal)]

When I run mypy --strict, I get complaints about the type of the list returned by the return statement:

error: List item 0 has incompatible type "list[T]"; expected "SupportsLT"  [list-item]
error: List item 0 has incompatible type "list[T]"; expected "SupportsGT"  [list-item]
error: List item 2 has incompatible type "list[T]"; expected "SupportsLT"  [list-item]
error: List item 2 has incompatible type "list[T]"; expected "SupportsGT"  [list-item]

The second attempt also has a more significant issue. For the following code, mypy reports Revealed type is "builtins.list[SupportsLT]", whereas I expected "builtins.list[builtins.int]":

from typing import reveal_type

x = quick_sort([12, 3])
reveal_type(x)

python version: 3.13.0, mypy version: 1.13.0

UPDATE:

The reason I'm using both SupportsLT and SupportsGT as bound while my function only utilizes the less-than operator (i.e., <) is that when the left-hand operand lacks the __lt__ method, Python calls the __gt__ method of the right-hand operand and passes the left operand as an argument. Thus, values with only __gt__ should be considered valid as input to my function. Consider the following simple example:

from typing import Self

class HasGT:
    def __init__(self, value: int) -> None:
        self.value = value

    def __gt__(self, other: Self) -> bool:
        return self.value > other.value

print(HasGT(5) < HasGT(6))  # Prints True
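
The fallback described above can be made visible by recording which method actually runs. This is only a minimal sketch: the class name OnlyGT and its calls attribute are hypothetical and exist purely for illustration. Strictly speaking, object supplies a default __lt__ that returns NotImplemented, and that return value is what makes Python try the reflected __gt__ on the right-hand operand.

```python
class OnlyGT:
    """Hypothetical class that defines __gt__ but no __lt__ of its own."""

    def __init__(self, value: int) -> None:
        self.value = value
        self.calls: list[str] = []  # records which dunder methods ran

    def __gt__(self, other: "OnlyGT") -> bool:
        self.calls.append("__gt__")
        return self.value > other.value


left, right = OnlyGT(5), OnlyGT(6)

# left has no usable __lt__, so Python evaluates right.__gt__(left) instead.
result = left < right

print(result, left.calls, right.calls)  # True [] ['__gt__']
```

Only the right-hand operand's list records a call, which matches the claim that `a < b` ends up as `b.__gt__(a)` here.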

Solution

    Union bound

    Your first attempt is indeed unsafe. Here is a demonstration:

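    from typing import Self
    
    # Assumes quick_sort from your first (union-bound) attempt is defined above.
    class HasGT: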
        def __init__(self, value: int) -> None:
            self.value = value
    
        def __gt__(self, other: Self) -> bool:
            return self.value > other.value
    
    
    class HasLT:
        def __init__(self, value: int) -> None:
            self.value = value
    
        def __lt__(self, other: Self) -> bool:
            return self.value < other.value
    
    
    foo: list[HasLT | HasGT] = [HasLT(2), HasGT(3)]
    sorted_foo = quick_sort(foo)
    

    mypy accepts this call (the only error it reports points at your function definition), but the call fails at runtime:

    $ mypy s.py --strict
    s.py:17: error: Unsupported left operand type for < (some union)  [operator]
    Found 1 error in 1 file (checked 1 source file)
    
    $ python s.py
    Traceback (most recent call last):
      File "/tmp/temp/s.py", line 56, in <module>
        sorted_foo = quick_sort(foo)
                     ^^^^^^^^^^^^^^^
      File "/tmp/temp/s.py", line 17, in quick_sort
        if item < pivot:
           ^^^^^^^^^^^^
    TypeError: '<' not supported between instances of 'HasGT' and 'HasLT'
    

    A TypeVar bound to some type T can be substituted with any T1 <= T (that is, any subtype of T, including T itself). When T is a union type, nothing stops T1 from being that same union; this is explicitly allowed. So this implementation is unsafe.
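
The same effect can be reproduced with builtins alone. The sketch below is an illustration, not part of the original code: the function first_is_smaller and the int | str bound are hypothetical stand-ins for quick_sort and SupportsLT | SupportsGT, and the classic TypeVar spelling is used so that it also runs on interpreters older than 3.12.

```python
from typing import TypeVar, Union

# Hypothetical bound standing in for SupportsLT | SupportsGT.
T = TypeVar("T", bound=Union[int, str])


def first_is_smaller(items: list[T]) -> bool:
    # Checkers are expected to flag this comparison at the definition,
    # because T may be solved to the whole union, not just int or just str.
    return items[0] < items[1]  # type: ignore[operator]


homogeneous: list[int] = [1, 2]
mixed: list[Union[int, str]] = [1, "two"]  # here T is the union itself

print(first_is_smaller(homogeneous))  # True
try:
    first_is_smaller(mixed)  # satisfies the bound, yet fails at runtime
except TypeError as exc:
    print("runtime failure:", exc)
```

The call with the mixed list satisfies the bound exactly as the heterogeneous list of HasLT and HasGT does above, and it fails at runtime for the same reason.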

    Constrained typevar

    Your second snippet is actually safe. There's a mypy bug making it reject your function as-is (something shady happens during unpacking, I'll have a look later), but replacing the last line with

    return quick_sort(smaller) + [pivot] + quick_sort(larger_or_equal)
    

    fixes things. Such an implementation passes mypy --strict, but is barely useful, as you noticed: its return type will be just that, a protocol with a single comparison method.
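
For reference, the constrained variant with the concatenation workaround might be assembled as follows. Treat it as a sketch: it uses the classic TypeVar("T", A, B) spelling instead of the [T: (A, B)] syntax from the question, on the assumption that the two are equivalent for this purpose, and the type-checking behaviour described in the comments is taken from the discussion above, not re-verified here.

```python
from typing import Any, Protocol, TypeVar


class SupportsLT(Protocol):
    def __lt__(self, other: Any, /) -> bool: ...


class SupportsGT(Protocol):
    def __gt__(self, other: Any, /) -> bool: ...


# Constrained TypeVar: T is solved to exactly SupportsLT or exactly
# SupportsGT, never to a union of the two.
T = TypeVar("T", SupportsLT, SupportsGT)


def quick_sort(seq: list[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    # Concatenation instead of unpacking into a list display.
    return quick_sort(smaller) + [pivot] + quick_sort(larger_or_equal)


# At runtime this sorts as expected; statically, the question reports that
# the result is typed as list[SupportsLT] rather than list[int].
print(quick_sort([12, 3, 7]))  # [3, 7, 12]
```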

    Reasonable compromise

    I think that it's reasonable to say "my implementation is fine" and provide the best possible signature for callers. It's reasonable to assume that the input collection is homogeneous, so let's just make two overloads (and also avoid restricting the input to lists: it works for any sequence):

    from collections.abc import Sequence
    from typing import Any, Protocol, Self, overload
    
    # [snip] Supports{L,G}T and Has{L,G}T definitions here
    
    @overload
    def quick_sort_overloaded[T: SupportsLT](seq: Sequence[T]) -> list[T]: ...
    @overload
    def quick_sort_overloaded[T: SupportsGT](seq: Sequence[T]) -> list[T]: ...
    def quick_sort_overloaded[T: SupportsLT | SupportsGT](seq: Sequence[T]) -> list[T]:
        if len(seq) <= 1:
            return list(seq)
        pivot = seq[0]
        smaller: list[T] = []
        larger_or_equal: list[T] = []
        for i in range(1, len(seq)):
            item = seq[i]
            if item < pivot:  # type: ignore[operator]
                smaller.append(item)
            else:
                larger_or_equal.append(item)
        return [
            *quick_sort_overloaded(smaller),  # type: ignore[type-var]
            pivot, 
            *quick_sort_overloaded(larger_or_equal)  # type: ignore[type-var]
        ]
    

    And now

    reveal_type(quick_sort_overloaded([HasLT(2), HasLT(3)])) # N: Revealed type is "builtins.list[__main__.HasLT]"
    reveal_type(quick_sort_overloaded([HasGT(2), HasGT(3)]))  # N: Revealed type is "builtins.list[__main__.HasGT]"
    
    foo: list[HasLT | HasGT] = [HasLT(2), HasGT(3)]
    try:
        quick_sort_overloaded(foo)  # E: Value of type variable "T" of "quick_sort_overloaded" cannot be "HasLT | HasGT"  [type-var]
    except TypeError:
        print("`quick_sort_overloaded` failed as warned")
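
Because the overloads accept any Sequence, tuples and other sequences of comparable values should work as well. The sketch below restates the overloaded function in the classic TypeVar spelling so that it can run on interpreters older than 3.12; the TypeVar names TLT, TGT and T are hypothetical, and the static behaviour is assumed to match the version above.

```python
from collections.abc import Sequence
from typing import Any, Protocol, TypeVar, Union, overload


class SupportsLT(Protocol):
    def __lt__(self, other: Any, /) -> bool: ...


class SupportsGT(Protocol):
    def __gt__(self, other: Any, /) -> bool: ...


# Hypothetical names; the version above declares these inline as [T: ...].
TLT = TypeVar("TLT", bound=SupportsLT)
TGT = TypeVar("TGT", bound=SupportsGT)
T = TypeVar("T", bound=Union[SupportsLT, SupportsGT])


@overload
def quick_sort_overloaded(seq: Sequence[TLT]) -> list[TLT]: ...
@overload
def quick_sort_overloaded(seq: Sequence[TGT]) -> list[TGT]: ...
def quick_sort_overloaded(seq: Sequence[T]) -> list[T]:
    if len(seq) <= 1:
        return list(seq)
    pivot = seq[0]
    smaller: list[T] = []
    larger_or_equal: list[T] = []
    for i in range(1, len(seq)):
        item = seq[i]
        if item < pivot:  # type: ignore[operator]
            smaller.append(item)
        else:
            larger_or_equal.append(item)
    return [
        *quick_sort_overloaded(smaller),  # type: ignore[type-var]
        pivot,
        *quick_sort_overloaded(larger_or_equal),  # type: ignore[type-var]
    ]


print(quick_sort_overloaded((3, 1, 2)))               # [1, 2, 3]
print(quick_sort_overloaded(["pear", "apple", "fig"]))  # ['apple', 'fig', 'pear']
```

The result is always a new list, whatever sequence type is passed in.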
    

    Here's a playground to compare all those solutions.