Search code examples
Tags: python, operator-overloading

Why does `__rsub__` subtract in reverse?


I have a custom class that inherits from `list`. `CustomList` instances support subtraction between themselves and with regular lists: the corresponding elements are subtracted, and if one of the operands is shorter, its missing elements are treated as zeros.

from typing import Iterable, List


class CustomList(list):
    def __init__(self, *args: Iterable[int | float]):
        if args == ():
            super().__init__(*args)
        elif all(isinstance(a, (int, float)) for a in args[0]):
            super().__init__(*args)
        else:
            raise TypeError("CustomList instance can contain int and float datatypes only")

    def __sub__(self, other: List[int | float] | "CustomList") -> "CustomList":
        if isinstance(other, (self.__class__, list)):
            result = CustomList()
            for i in range(max(len(self), len(other))):
                if i < len(self) and i < len(other):
                    result.append(self[i] - other[i])
                elif i < len(self):
                    result.append(self[i])
                else:
                    result.append(other[i])
            return result

    def __rsub__(self, other):
        return self.__sub__(other)

`__sub__` works as expected, but when the first operand is a plain list (so `__rsub__` is called), the subtraction is reversed: instead of `operand1 - operand2` it computes `operand2 - operand1`.
For example:

`[1, 2, 3] - CustomList([2, 4, 6])` should be `CustomList([-1, -2, -3])`, but I get `CustomList([1, 2, 3])` instead. Why does this happen?


Solution

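  • `__rsub__` is called when you evaluate `list - CustomList`, say `L - myList`. Python invokes `myList.__rsub__(L)`, so inside `__rsub__` `self` is the right-hand operand and `other` is the left-hand one. Defined the way you did, `__rsub__` therefore returns the result of `myList - L`. But subtraction is not commutative (`a - b != b - a`).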

    My recommended solution is to overload the negation operator (`__neg__`) to define what `-myList` is, and then modify `__rsub__` to return `-self.__sub__(other)`. This works because `L - myList == -(myList - L)`.

    Here is the full solution:

    from typing import Iterable, List
    
    
    class CustomList(list):
        def __init__(self, *args: Iterable[int | float]):
            if args == ():
                super().__init__(*args)
            elif all(isinstance(a, (int, float)) for a in args[0]):
                super().__init__(*args)
            else:
                raise TypeError("CustomList instance can contain int and float datatypes only")
    
        def __sub__(self, other: List[int | float] | "CustomList") -> "CustomList":
            if isinstance(other, (self.__class__, list)):
                result = CustomList()
                for i in range(max(len(self), len(other))):
                    if i < len(self) and i < len(other):
                        result.append(self[i] - other[i])
                    elif i < len(self):
                        result.append(self[i])
                    else:
                        result.append(other[i])
                return result
    
        def __neg__(self):
            result = CustomList()
            for val in self:
                result.append(-val)
            return result
    
        def __rsub__(self, other):
            return - self.__sub__(other)
    
    >>> [1, 2, 3] - CustomList([2, 4, 6])
    [-1, -2, -3]