
Correct override of __rmul__ in Cython extension types


How should I override __rmul__ in Cython?

For example, this works perfectly in Python:

class PyFoo:
  def __init__(self):
      self.a = 4
  def __mul__(self,b):  return b*self.a
  def __rmul__(self,b): return self*b

Pynew = PyFoo()

print "   Python   "
print Pynew*3 # I got 12
print 3*Pynew # I got 12

But the same implementation does not work in Cython:

cclass.pyx

cimport cython

cdef class Foo:
  cdef public int a
  def __init__(self):
      self.a = 4
  def __mul__(self,b):  return b*self.a
  def __rmul__(self,b): return self*b

test.py

import cclass as c
Cnew = c.Foo()
print "   Cython   "
print Cnew*3 # This works, I got 12
print 3*Cnew # This doesn't

The second print raises this error:

Traceback (most recent call last):
  File "test.py", line 22, in <module>
    print 3*Cnew
  File "cclass.pyx", line 8, in cclass.Foo.__mul__ (cclass.c:763)
    def __mul__(self,b):  return b*self.a
AttributeError: 'int' object has no attribute 'a'

I don't understand why the same implementation of __rmul__ fails in Cython.


Solution

  • This is a case of not reading the documentation. In the Special Methods of Extension Types user guide you will find the following:

    Arithmetic operator methods, such as __add__(), behave differently from their Python counterparts. There are no separate “reversed” versions of these methods (__radd__(), etc.) Instead, if the first operand cannot perform the operation, the same method of the second operand is called, with the operands in the same order.

    This means that you can’t rely on the first parameter of these methods being “self” or being the right type, and you should test the types of both operands before deciding what to do. If you can’t handle the combination of types you’ve been given, you should return NotImplemented.

    So you should really be doing some type checking, at least along the following lines:

    cdef class Foo:
        cdef public int a
    
        def __init__(self):
            self.a = 4
    
        def __mul__(first, other):
            if isinstance(first, Foo):
                return first.a * other
            elif isinstance(first, int):
                return first * other.a
            else:
                return NotImplemented
    

    This solution is overly optimistic about how the Foo class will be used: it never checks the type of other, so you may need to check that as well, and/or test against a more generic number type than int.
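The stricter check suggested above could look roughly like the following. This is a plain-Python sketch, not Cython: the free function foo_mul is a hypothetical stand-in for the body of the extension type's __mul__, written as a function so that it can be called with the Foo instance in either position, the way Cython calls it in the behaviour quoted from the documentation. Using numbers.Number as the "more generic number type" is an assumption; a narrower check may suit your class better.

```python
from numbers import Number


class Foo:
    """Plain-Python stand-in for the cdef class; only the attribute `a` matters here."""

    def __init__(self, a=4):
        self.a = a


def foo_mul(first, other):
    """Hypothetical body of __mul__ when either operand may be the Foo instance."""
    if isinstance(first, Foo) and isinstance(other, Number):
        return first.a * other   # Foo * number
    if isinstance(first, Number) and isinstance(other, Foo):
        return first * other.a   # number * Foo
    # Unsupported combination: let Python try other options or raise TypeError.
    return NotImplemented


print(foo_mul(Foo(), 3))    # Foo on the left
print(foo_mul(3, Foo()))    # Foo on the right
print(foo_mul(Foo(), "x"))  # unsupported operand type
```

In the .pyx file the same three branches would go inside def __mul__(first, other). Note that the quoted behaviour applies to Cython versions before 3.0; from 3.0 on, extension types follow Python semantics by default and a separate __rmul__ is honoured, with the older behaviour available through the c_api_binop_methods directive.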