Search code examples
python, numpy, numpy-ndarray

Have numpy.concatenate return proper subclass rather than plain ndarray


I have a numpy array subclass, and I'd like to be able to concatenate instances of it.

import numpy as np

class BreakfastArray(np.ndarray):
    """Structured ndarray subclass with integer ``waffles`` and ``eggs`` fields."""

    def __new__(cls, n=1):
        """Return a zero-filled, length-``n`` array viewed as ``cls``."""
        fields = [("waffles", int), ("eggs", int)]
        return np.zeros(n, dtype=fields).view(cls)
        
# Build two arrays of different lengths and join them with NumPy.
b1 = BreakfastArray(n=1)
b2 = BreakfastArray(n=2)
con_b1b2 = np.concatenate((b1, b2))

# Show the class of an input next to the class of the concatenated result.
print(type(b1), type(con_b1b2))

This outputs <class '__main__.BreakfastArray'> <class 'numpy.ndarray'>, but I'd like the concatenated array to be a BreakfastArray as well. It looks like I probably need to add an __array_finalize__ method, but I can't figure out the right way to do it.


Solution

  • Expanding on simon's solution, this is what I settled on, so that other numpy functions fall back to standard ndarray behavior (so numpy.unique(b2["waffles"]) works as expected). I also made a slight change to concatenate so that it works for any subclass of BreakfastArray as well.

    import numpy as np
    
    # Registry mapping a NumPy function to the replacement that should run
    # when that function is called with BreakfastArray arguments.
    HANDLED_FUNCTIONS = {}

    class BreakfastArray(np.ndarray):
        """Structured ndarray subclass with integer ``waffles`` and ``eggs`` fields.

        NumPy functions registered in ``HANDLED_FUNCTIONS`` are routed to their
        replacement; every other function runs on plain ``ndarray`` views of
        the arguments, so it behaves like standard NumPy.
        """

        def __new__(cls, *args, n=1, **kwargs):
            """Return a zero-filled, length-``n`` array viewed as ``cls``.

            Extra positional and keyword arguments are accepted and ignored.
            """
            dtypes = [("waffles", int), ("eggs", int)]
            obj = np.zeros(n, dtype=dtypes).view(cls)
            return obj

        @staticmethod
        def _to_plain(value):
            """Return ``value`` with every BreakfastArray replaced by an ndarray view.

            Lists, tuples and dicts are rebuilt recursively so that arrays
            nested inside a sequence argument (e.g. ``np.vstack([a, b])``) are
            converted too; any other object is returned unchanged.
            """
            if isinstance(value, BreakfastArray):
                return value.view(np.ndarray)
            if type(value) in (list, tuple):
                return type(value)(BreakfastArray._to_plain(v) for v in value)
            if type(value) is dict:
                return {k: BreakfastArray._to_plain(v) for k, v in value.items()}
            return value

        def __array_function__(self, func, types, args, kwargs):
            """Dispatch a NumPy function call involving a BreakfastArray.

            Parameters
            ----------
            func : callable
                The public NumPy function that was called.
            types : collection of type
                Argument types that implement ``__array_function__``.
            args : tuple
                Positional arguments of the original call.
            kwargs : dict
                Keyword arguments of the original call.

            Returns
            -------
            object
                The result of the registered replacement, the result of
                ``func`` on plain ndarray views, or ``NotImplemented`` when a
                replacement exists but ``types`` contains a type that is not
                a BreakfastArray subclass.
            """
            if func not in HANDLED_FUNCTIONS:
                # "Standard NumPy behavior": strip the subclass everywhere,
                # including nested sequences and keyword arguments. A
                # BreakfastArray left in place would make func() dispatch
                # back here and recurse without end.
                plain_args = BreakfastArray._to_plain(args)
                plain_kwargs = BreakfastArray._to_plain(kwargs)
                return func(*plain_args, **plain_kwargs)
            if not all(issubclass(t, BreakfastArray) for t in types):
                return NotImplemented
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
    
    def implements(numpy_function):
        """Return a decorator that registers a replacement for ``numpy_function``.

        The decorated callable is stored in ``HANDLED_FUNCTIONS`` under
        ``numpy_function`` and returned unchanged.
        """
        def register(implementation):
            HANDLED_FUNCTIONS[numpy_function] = implementation
            return implementation

        return register
    
    @implements(np.concatenate)
    def concatenate(arrays, axis=0, out=None, **kwargs):
        """Concatenate arrays and return an instance of the first input's class.

        Parameters
        ----------
        arrays : iterable of array
            Arrays to join. A freshly allocated result is sized with
            ``sum(len(a) for a in arrays)``, so one-dimensional inputs are
            expected when ``out`` is not supplied.
        axis : int or None, optional
            Forwarded to ``np.concatenate``.
        out : ndarray, optional
            Destination array. When omitted, a new instance of the first
            input's class is created with its ``n=`` constructor argument.
        **kwargs
            Any further keyword arguments (e.g. ``dtype``, ``casting``) are
            forwarded to ``np.concatenate`` unchanged.

        Returns
        -------
        ndarray
            ``out`` if it was supplied, otherwise an array of the first
            input's class.

        Raises
        ------
        ValueError
            If ``arrays`` is empty.
        """
        arrays = list(arrays)
        if not arrays:
            raise ValueError("need at least one array to concatenate")
        result_cls = type(arrays[0])
        # Concatenate base-class views of the inputs, as the original did.
        plain = [np.asarray(a) for a in arrays]
        if out is None and kwargs.get("dtype") is None:
            out = result_cls(n=sum(len(a) for a in arrays))
            return np.concatenate(plain, axis=axis, out=out, **kwargs)
        # np.concatenate rejects ``out`` together with ``dtype``, so when a
        # dtype is requested NumPy allocates the result and it is re-viewed
        # as the subclass afterwards.
        result = np.concatenate(plain, axis=axis, out=out, **kwargs)
        if out is None:
            return result.view(result_cls)
        return result