Tags: python-3.x, subclass, numpy-ndarray, dtype

How to write a subclass of numpy.ndarray which only takes complex values?


I would like to create a subclass of numpy.ndarray that is an array of complex numbers. To that end, I'm trying to write the constructor of my subclass so that it returns an array filled with (0+0j). I haven't managed it yet. Here is my code so far:

import numpy as np


class ComplexArray(np.ndarray):
    def __init__(self, args):
        np.ndarray.__init__(args, dtype=complex)
        self.fill(0)


a = ComplexArray(3)
a[0] = 1j

When I run the above code, I get the error TypeError: can't convert complex to float.
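A likely cause, shown as a minimal sketch: for an ndarray subclass, the array (including its dtype) is allocated by np.ndarray.__new__, which runs before __init__ and defaults to float64 when no dtype is given. Calling np.ndarray.__init__(..., dtype=complex) afterwards does not change that allocation (and the call above also passes args rather than self as the instance). The class FloatBacked and its attribute dtype_at_init are hypothetical, used only to record the dtype that is already in place inside __init__.

```python
import numpy as np


class FloatBacked(np.ndarray):
    # Hypothetical class: records the dtype seen inside __init__.
    def __init__(self, shape):
        # np.ndarray.__new__ has already allocated the array with the
        # default dtype (float64), so the dtype is fixed at this point.
        self.dtype_at_init = self.dtype


a = FloatBacked(3)
print(a.dtype_at_init)  # float64
print(a.dtype)          # float64

try:
    a[0] = 1j  # a float64 array cannot store a complex value
except TypeError as exc:
    print("TypeError:", exc)
```

This reproduces the same kind of TypeError as the question, which points to the dtype needing to be chosen in __new__ rather than in __init__.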

The reason I want such a subclass is that I plan to implement several methods in it afterwards.

Thank you in advance for your advice!


Solution

  • I have found a solution:

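    import numpy as np


    class ComplexArray(np.ndarray):
        def __new__(cls, n):
            # An ndarray's dtype is fixed when the array is allocated, which
            # happens in __new__ (before __init__), so choose complex here.
            # n can be any shape np.zeros accepts (an int or a tuple).
            ret = np.zeros(n, dtype=complex)
            # Reinterpret the complex128 zeros as a ComplexArray (no copy).
            return ret.view(cls)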
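A usage sketch of this approach, extended under two assumptions that are not part of the answer above: the constructor also accepts existing data through a hypothetical data parameter (cast with np.asarray(..., dtype=complex)), and a hypothetical method magnitude shows where the planned extra methods could go.

```python
import numpy as np


class ComplexArray(np.ndarray):
    def __new__(cls, shape=(), data=None):
        # Assumption: either a shape (zero-filled) or existing data is given.
        if data is not None:
            base = np.asarray(data, dtype=complex)  # cast the data to complex128
        else:
            base = np.zeros(shape, dtype=complex)   # complex zeros (0j)
        # view() reinterprets the complex array as a ComplexArray without copying.
        return base.view(cls)

    def magnitude(self):
        # Hypothetical method: element-wise modulus as a plain ndarray.
        return np.abs(np.asarray(self))


a = ComplexArray(3)
a[0] = 1j               # works: the underlying dtype is complex128
print(a.dtype)          # complex128

b = ComplexArray(data=[3 + 4j, 1])
print(b.magnitude())    # [5. 1.]
```

Ufunc results computed from a subclass instance, such as np.abs(a), may come back as a ComplexArray holding float values, which is why magnitude converts to a plain ndarray first. This sketch does not prevent other operations from producing non-complex ComplexArray instances.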