
Overloading NumPy functions?


I am hoping for some clarification on overloading NumPy universal functions in class methods.

To illustrate, here is a class myreal with an overloaded cos method. This overloaded method calls cos imported from the math module.

from math import cos

class myreal:
    def __init__(self, x):
        self.x = x

    def cos(self):
        # Wrap the result of the module-level cos (math.cos here) in a new myreal
        return self.__class__(cos(self.x))

    def __str__(self):
        return str(self.x)

x = myreal(3.14)
y = myreal.cos(x)
print(x,y)

This works as expected and prints

3.14 -0.9999987317275395

And, as expected, simply trying

z = cos(x)

results in the error TypeError: must be real number, not myreal, because math.cos knows nothing about the myreal class and only accepts real numbers.
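A minimal sketch reproducing this failure in isolation (the class is trimmed to the parts that matter). It shows that math.cos never looks for a cos method on its argument, even when one exists:

```python
import math

class myreal:
    def __init__(self, x):
        self.x = x

    def cos(self):
        # Same method name as math.cos, but math.cos never consults it
        return self.__class__(math.cos(self.x))

x = myreal(3.14)
try:
    math.cos(x)  # math.cos wants a real number, not an arbitrary object
    error = None
except TypeError as exc:
    error = exc
    print("TypeError:", exc)
```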

But surprisingly (and here is my question), if I now import cos from numpy, I can call cos(x) as a function, rather than as a method of the myreal class. In other words, this now works:

from numpy import cos
z = cos(x)

So it seems that the myreal.cos() method is now able to overload the NumPy global function cos(x). Is this "multiple dispatch" behavior included by design?

Checking the type of NumPy's cos reveals that it is of type `numpy.ufunc`, which suggests an explanation involving NumPy universal functions.
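A short check of that observation, contrasting NumPy's cos with the one from the math module. The nin and nout attributes are standard ufunc attributes giving the number of inputs and outputs:

```python
import math
import numpy as np

# math.cos is an ordinary built-in function
print(type(math.cos))                # <class 'builtin_function_or_method'>

# np.cos is a NumPy universal function (ufunc)
print(type(np.cos))                  # <class 'numpy.ufunc'>
print(isinstance(np.cos, np.ufunc))  # True

# Every ufunc records how many inputs and outputs it takes
print(np.cos.nin, np.cos.nout)       # 1 1
```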

Any clarification about what is going on here would be very interesting and helpful.


Solution

  • np.cos works as expected when given a numeric dtype array (or anything that converts to one):

    In [246]: np.cos(np.array([1,2,np.pi]))
    Out[246]: array([ 0.54030231, -0.41614684, -1.        ])
    

    But if I give it an object dtype array, I get an error:

    In [248]: np.cos(np.array([1,2,np.pi,None]))
    ---------------------------------------------------------------------------
    AttributeError                            Traceback (most recent call last)
    AttributeError: 'int' object has no attribute 'cos'
    
    The above exception was the direct cause of the following exception:
    
    TypeError                                 Traceback (most recent call last)
    Input In [248], in <cell line: 1>()
    ----> 1 np.cos(np.array([1,2,np.pi,None]))
    
    TypeError: loop of ufunc does not support argument 0 of type int which has no callable cos method
    

    Because of the None element, the array has object dtype:

    In [249]: np.array([1,2,np.pi,None])
    Out[249]: array([1, 2, 3.141592653589793, None], dtype=object)
    

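    When a ufunc is given an object dtype array, it iterates through the array and tries to apply each element's corresponding method. For an operator-like ufunc such as np.add, that means the Python + operator, i.e. the element's __add__ method. For functions like cos (and exp and sqrt) it looks for a method of the same name. Usually this fails because most objects don't have a cos method. In your case it didn't fail, because you defined one. A bare myreal instance is handled the same way: NumPy can't convert it to a numeric array, so np.cos(x) treats it as a 0-d object dtype array and ends up calling x.cos().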

    Try np.sin(x) to see the error: myreal has no sin method, so the object loop fails just like the None example above.

    I wouldn't call this overloading. It's just a quirk of how numpy handles object dtype arrays.
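A minimal script reproducing the whole chain, assuming a reasonably recent NumPy. The stand-alone myreal below calls np.cos on its stored float, which is what the question's cos method ends up doing once `from numpy import cos` has rebound the name. The exact TypeError wording may differ between NumPy versions:

```python
import numpy as np

class myreal:
    def __init__(self, x):
        self.x = x

    def cos(self):
        # Called by np.cos's object loop; self.x is a plain float
        return self.__class__(np.cos(self.x))

    def __str__(self):
        return str(self.x)

x = myreal(3.14)

# NumPy can't make a numeric array from x, so it becomes a 0-d object array
print(np.asarray(x).dtype)  # object

# The object loop of np.cos calls x.cos(); the 0-d result is unwrapped,
# so z is the myreal instance that x.cos() returned
z = np.cos(x)
print(type(z).__name__, z)

# myreal defines no sin method, so the object loop of np.sin fails
try:
    np.sin(x)
    sin_error = None
except TypeError as exc:
    sin_error = exc
    print("TypeError:", exc)
```

Nothing here registers myreal with NumPy; the behaviour comes entirely from the object dtype loop looking up a method by name.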