
Comparable types with mypy


I'm trying to create a generic class to express that a value has lower and upper bounds, and to enforce those bounds.

from typing import Any, Optional, TypeVar

T = TypeVar("T")

class Bounded(object):
    def __init__(self, minValue: T, maxValue: T) -> None:
        assert minValue <= maxValue
        self.__minValue = minValue
        self.__maxValue = maxValue

However, mypy complains that:

error: Unsupported left operand type for <= ("T")

Apparently the typing module doesn't allow me to express this (although it looks like a Comparable type might be added in the future).

I think it would be enough to check that the object has __eq__ and __lt__ methods (for my use case, at least). Is there currently any way to express this requirement in Python so that mypy understands it?


Solution

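  • After a bit more research, I found a solution: protocols (PEP 544). Since they were not yet part of the standard typing module as of Python 3.6, they have to be imported from the typing_extensions module. (From Python 3.8 on, Protocol can be imported directly from typing.)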

    import typing
    from typing import Any
    from typing_extensions import Protocol
    from abc import abstractmethod
    
    C = typing.TypeVar("C", bound="Comparable")
    
    class Comparable(Protocol):
        @abstractmethod
        def __eq__(self, other: Any) -> bool:
            pass
    
        @abstractmethod
        def __lt__(self: C, other: C) -> bool:
            pass
    
        def __gt__(self: C, other: C) -> bool:
            return (not self < other) and self != other
    
        def __le__(self: C, other: C) -> bool:
            return self < other or self == other
    
        def __ge__(self: C, other: C) -> bool:
            return (not self < other)
    

    Now we can define our type as:

    C = typing.TypeVar("C", bound=Comparable)
    
    class Bounded(object):
        def __init__(self, minValue: C, maxValue: C) -> None:
            assert minValue <= maxValue
            self.__minValue = minValue
            self.__maxValue = maxValue
    

    And Mypy is happy:

    from functools import total_ordering
    
    @total_ordering
    class Test(object):
        def __init__(self, value):
            self.value = value
        def __eq__(self, other):
            return self.value == other.value
        def __lt__(self, other):
            return self.value < other.value
    
    Bounded(Test(1), Test(10))
    Bounded(1, 10)
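The question also asks for the bounds to be enforced, which the snippets above only assert in the constructor. Below is a minimal sketch of one way to do that, assuming Python 3.8+ (or typing_extensions on older versions). It uses a simplified protocol that only requires __lt__ and __le__, not the full Comparable from the answer, and makes Bounded generic so mypy ties both bounds and every checked value to one type. The method names contains and clamp, and the snake_case attribute names, are hypothetical choices made for illustration.

```python
import sys
from typing import Any, Generic, TypeVar

# Protocol lives in typing from Python 3.8; older versions need typing_extensions.
if sys.version_info >= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol


class Comparable(Protocol):
    """Structural type: any object that supports < and <= is accepted."""

    def __lt__(self, other: Any) -> bool: ...

    def __le__(self, other: Any) -> bool: ...


C = TypeVar("C", bound=Comparable)


class Bounded(Generic[C]):
    """An inclusive [min_value, max_value] range whose bounds are checked."""

    def __init__(self, min_value: C, max_value: C) -> None:
        # Raise instead of assert so the check survives `python -O`.
        if not min_value <= max_value:
            raise ValueError("min_value must not be greater than max_value")
        self._min_value = min_value
        self._max_value = max_value

    def contains(self, value: C) -> bool:
        # Chained comparison: min_value <= value and value <= max_value.
        return self._min_value <= value <= self._max_value

    def clamp(self, value: C) -> C:
        # Return the nearest bound when value lies outside the range.
        if value < self._min_value:
            return self._min_value
        if self._max_value < value:
            return self._max_value
        return value


percent = Bounded(0, 100)        # mypy infers Bounded[int]
print(percent.contains(42))      # True
print(percent.clamp(150))        # 100
letters = Bounded("a", "m")      # mypy infers Bounded[str]
print(letters.contains("z"))     # False
```

Because the class is Generic[C], a call such as percent.contains("x") should be flagged by mypy as an argument type error, since percent is a Bounded[int]. A single leading underscore is used here instead of the double-underscore name mangling in the original; either works.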