Search code examples
python inheritance numpy attributeerror default-arguments

numpy subclass will not accept arguments to __new__ from pythonically inheriting class


I've created a subclass of ndarray called "Parray" which takes two arguments: p, and dimensionality. It works fine on its own. Now, I want to create a class called SirPlotsAlot, which inherits from Parray without all the fancy __new__ and __array_finalize__ etc.

import numpy as np

class Parray(np.ndarray):
    def __new__(self, p = Parameters(), dimensionality = 2):

        print "Initializing Parray with initial dimensionality %s..." % dimensionality

        self.p = p # store the parameters

        if dimensionality == 2:
            shape = (p.nx, p.ny)
            self.pshape = shape
        elif dimensionality == 3:
            shape=(p.nx, p.ny, p.nx)
            self.pshape = shape
        else:
            raise NotImplementedError, "dimensionality must be 2 or 3"

        # ...Set other variables (ellided)

        subarr = np.ndarray.__new__(self, shape, dtype, buffer, offset, strides, order)
        subarr[::] = np.zeros(self.pshape) # initialize to zero
        return subarr
...

class SirPlotsAlot(Parray):
    # NOTE(review): Python calls __new__ before __init__ with the same
    # constructor arguments, so the array already exists when this runs.
    def __init__(self, p = Parameters(), dimensions = 3):
        # __new__ is reached through super() and called without an explicit
        # class argument, so `p` is bound to its first parameter here.
        super(SirPlotsAlot, self).__new__(p, dimensions)     # (1)

Objects in my program share sets of parameters by passing an object p = Parameters() back and forth.

Now, when I type (the file is auxiliary.py):

import auxiliary
from parameters import Parameters
p = Parameters()
s = auxiliary.SirPlotsAlot(p, 3)

expecting to get a nice "Initializing Parray with initial dimensionality 3", I get "2", instead. BUT if I type:

import auxiliary
s = auxiliary.SirPlotsAlot()

I get

---> 67             shape = (p.nx, p.ny)
"AttributeError: 'int' object has no attribute 'nx'"

It thinks "p" is an int, which it is not. I can get lots of weird seemingly unrelated errors if I play around with it. The int it thinks it is is "2". I'm completely lost.

I've tried with and without the # (1) comment (the super call).

Other errors from playing around include "AttributeError: 'list' object has no attribute 'p'", "TypeError: __new__() takes exactly 2 arguments (1 given)", "ValueError: need more than 0 values to unpack" (I replaced __new__'s arguments with *args, something I don't understand very well).


Solution

  • It's been ten years and I left the project long ago, but I resolved this issue by creating helper functions that create new instances and set them up. In the code example below, see the definitions at the bottom of the file. I imported and used those.

    Props to Matthew Schinckel for pointing out that __new__ should have already been called by the time __init__ runs, and to everyone else for their thoughts.

    # -*- coding: utf-8 -*-
    """
    Era's Plotting Functionality. This module exports SirPlotsAlot and company:
        
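    class SirPlotsAlot: Array with 2D, 3D, animated plotting capability, and a pyrism Parameters object.
    
    def NewSirPlotsAlot(shape): returns a zero-filled SirPlotsAlot of the given shape
    
    def NewPsirPlotsAlot(dimensionality, p): returns a zero-filled SirPlotsAlot sized from a Parameters object; both arguments are optional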
    
    def returns_SirPlotsAlot: decorator force ndarray-returning function to return SirPlotsAlot instead.
    
    Created on Thu Jul 12 18:46:15 2012
    @author: Era
    """
    
    # SirPlotsAlot
    from mpl_toolkits.mplot3d import Axes3D
    import matplotlib.animation as animation
    import matplotlib.pyplot as pyplot
    import numpy as np
    import scipy as scipy
    import logging
    
    lprint = logging.getLogger('pyrism')
    
    
    
    class SirPlotsAlot(np.ndarray):
        """
        An ndarray subclass with 1D, 2D, 3D and animated plotting helpers.

        Inherits: numpy's ndarray

        Input:
            shape: A tuple of ints. The shape of the underlying ndarray.
                The contents are NOT initialized by the constructor; use the
                module-level helper functions to get a zero-filled array.

        Optional instance attributes:
            plot2D_label, plot3D_label, plot3D_2_label: fallback titles used
                by the matching plot method when no string label is passed.

        """

        # class variables
        currentSlice = 0    # for _updateSlice and animated plots

        def __new__(cls, shape):
            """
            Creates a new, uninitialized float SirPlotsAlot of the given shape.

            ndarray instances are built in __new__ rather than __init__, so a
            subclass has to hook in here to control construction.

            Args:
                shape: A tuple of ints. The shape of the underlying ndarray.

            Returns:
                A SirPlotsAlot (an ndarray whose type is cls).

            Author / Date:
                Erasmus Alcarin   /   January 23rd, 2012
                Erasmus Alcarin   /      July 13th, 2012

            """
            # Passing cls (not np.ndarray) is crucial: it makes the new array
            # an instance of THIS class instead of a plain ndarray.
            return np.ndarray.__new__(
                cls,
                shape,
                dtype=float,    # element type of the array
                buffer=None,    # no existing buffer: allocate fresh memory
                offset=0,       # offset of the array data in the buffer
                strides=None,   # default strides
                order=None,     # default (row-major) memory layout
            )

        def __init__(self, shape):
            """Says hello!

            The array itself has already been created by __new__ when this
            runs, so there is nothing left to set up here.

            Args:
                shape: A tuple of ints. The shape of the underlying ndarray.

            Returns:
                None

            Author:
                Erasmus Alcarin   /   January 23rd, 2012
                Erasmus Alcarin   /      July 13th, 2012

            """
            lprint.debug("Ah, kind sir! Thy bidding be done!")

        def __array_finalize__(self, obj):
            """Hook NumPy calls whenever a new SirPlotsAlot comes into being.

            Purpose:  NumPy invokes this for every way an instance can be
                      created: the explicit constructor, view casting, and
                      "new-from-template" (slicing). It is the place to copy
                      or default per-instance attributes. This class has none
                      to propagate, so the hook does nothing.

            Args:
                obj: The array the new one was derived from, or None when the
                    array came from the explicit constructor. For example,
                    after
                        myArr = myIntensityMap[1:]
                    self is myArr (the new array) and obj is myIntensityMap.

            Returns:
                None

            Author / Date:
                Erasmus Alcarin   /   January 23rd, 2012

            """
            if obj is None:
                return

        def __array_wrap__(self, out_arr, context=None):
            """Make results of NumPy operations come back as SirPlotsAlot.

            Purpose:  NumPy calls this on the result of ufuncs and similar
                      operations (adding, multiplying, etc.) so the subclass
                      can decide what type the result has.

            Args:
                out_arr: What is returned in the operation which is being
                    performed.
                context: Optional ufunc context NumPy passes along; it is
                    forwarded unchanged to ndarray.__array_wrap__.

            Returns:
                out_arr wrapped by ndarray.__array_wrap__ and viewed as the
                type of self.

            Author / Date:
                Erasmus Alcarin   /   January 23rd, 2012

            """
            # NOTE(review): NumPy 2 passes a third `return_scalar` argument to
            # __array_wrap__ -- confirm against the NumPy version in use.
            wrapped = np.ndarray.__array_wrap__(self, out_arr, context)
            return wrapped.view(type(self))

        def _enforceXD(self, X):
            """
            A helper that insists this SirPlotsAlot has dimensionality X.

            Args:
                X: An int. The underlying ndarray dimensionality being tested for.

            Returns:
                True if this array has dimensionality X.

            Raises:
                ValueError: if the array has any other dimensionality.

            """
            ndim = len(self.shape)
            if ndim == X:
                return True
            raise ValueError("A %sD array was required. A %sD array was supplied." % (X, ndim))

        def _checkXD(self, X):
            """
            A helper reporting whether this SirPlotsAlot has dimensionality X.
            Unlike _enforceXD it never raises.

            Args:
                X: An int. The underlying ndarray dimensionality being tested for.

            Returns:
                True if this array has dimensionality X.
                False otherwise.

            """
            return len(self.shape) == X

        def _checkLabelInfo(self, label = None):
            """
            A helper that applies a user-supplied label to the current plot.

            Args:
                label: The user's input. Accepted formats:
                    None   -- nothing is labelled.
                    String -- used as the plot title.
                    Tuple  -- (title, xlabel, ylabel, zlabel); trailing
                              entries may be omitted and non-string entries
                              are skipped.

            Returns:
                Nothing

            """
            if label is None:
                return
            if isinstance(label, str):
                # Indexing a bare string would label with single characters.
                pyplot.title(label)
                return

            for text, apply_label in zip(label, (pyplot.title, pyplot.xlabel, pyplot.ylabel)):
                if isinstance(text, str):
                    apply_label(text)

            if len(label) >= 4 and isinstance(label[3], str):
                # pyplot has no zlabel(); only 3D axes expose set_zlabel, so
                # the z label is skipped on ordinary 2D axes.
                axes = pyplot.gca()
                if hasattr(axes, 'set_zlabel'):
                    axes.set_zlabel(label[3])

        def _add_labels(self, label = None, caller_label = 'none'):
            """
            A utility function to quickly add labels to
            any of the graphing utilities embedded in
            SirPlotsAlot.

            Args:
                label: (Str, Tuple) The label being supplied by the user.
                caller_label: A string. Name of the instance attribute holding
                    the calling plot function's default label (for example
                    'plot2D_label'); used when label is neither a string nor
                    a tuple.

            Returns:
                Nothing

            Raises:
                ValueError: if label is None, or if label is neither a string
                    nor a tuple and no attribute named caller_label exists.

            """
            if label is None:
                raise ValueError("_add_labels violated")
            lprint.debug("going on to labelling")

            if not isinstance(label, (str, tuple)):
                # Fall back to the per-plot default stored on the instance.
                if not hasattr(self, caller_label):
                    raise ValueError("_add_labels requires string or tuple of strings")
                label = getattr(self, caller_label)

            if isinstance(label, str):
                pyplot.title(label)
            elif isinstance(label, tuple):
                self._checkLabelInfo(label)

        def _updateSlice(self):
            """
            A helper for animated plots: advances currentSlice by one and
            returns the 2D slice at the new position.

            Args:
                None

            Returns:
                2D slice of this array.

            Raises:
                ValueError: if this array is not 3D.
                IndexError: once currentSlice runs past the first axis.

            """
            self._enforceXD(3)

            # NOTE(review): the counter is advanced BEFORE indexing, so the
            # first call returns slice 1, not slice 0 -- confirm intended.
            self.currentSlice += 1
            return self[self.currentSlice]

        def plot1D(self, label = None):
            """
            Plot 1-axis SirPlotsAlot in 2D, plotting array contents as y (up).

            Args:
                label: Accepted for symmetry with the other plot methods but
                    currently not applied to the figure.

            Returns:
                Nothing

            Raises:
                ValueError: if this array is not 1D.

            """
            self._enforceXD(1)

            pyplot.figure()
            pyplot.plot(self)
            pyplot.show()

        def plot2D(self, label = None):
            """
            Plot 2-axis SirPlotsAlot in 2D, plotting array contents as color.

            Args:
                label: String used as the plot title. If it is not a string,
                    the instance attribute plot2D_label is used when present.

            Returns:
                Nothing

            Raises:
                ValueError: if this array is not 2D.

            """
            self._enforceXD(2)

            # Do not produce huge output: log a notice, not the array itself.
            lprint.debug("We're plotting you some goodies!")

            fig = pyplot.figure()
            if isinstance(label, str):      # if label is supplied, apply it.
                pyplot.title(label)
            elif hasattr(self, 'plot2D_label'):
                pyplot.title(self.plot2D_label)

            plot = pyplot.imshow(self)
            fig.colorbar(plot)

            # imshow puts row 0 at the top; flip so y increases upwards.
            pyplot.gca().invert_yaxis()

            pyplot.xlabel('x')
            pyplot.ylabel('y')

            pyplot.show()

        def save_plot2D(self, file = None, label = None, cbar_ticks = None):
            """
            Saves plot of 2-axis SirPlotsAlot in 2D, plotting array contents as color,
            in .png format.

            Args:
                file: A string. The filename to save to. Default: ``output``
                label: String or tuple labelling the plot.
                cbar_ticks: Colorbar ticks for plot. Default: auto.

            Returns:
                Nothing

            Raises:
                ValueError: if this array is not 2D.

            """
            self._enforceXD(2)

            # Do not produce huge output: log a notice, not the array itself.
            lprint.debug("We're plotting you some goodies!")

            fig = pyplot.figure()
            if label is not None:
                self._add_labels(label, 'plot2D_label')

            plot = pyplot.imshow(self)

            # `is None` rather than `== None`: cbar_ticks may be an array, for
            # which `==` is elementwise and cannot be used as a condition.
            if cbar_ticks is None:
                fig.colorbar(plot)
            else:
                cbar = fig.colorbar(plot, ticks=cbar_ticks)             # Numbers
                cbar.ax.set_yticklabels([str(t) for t in cbar_ticks])   # Label

            # imshow puts row 0 at the top; flip so y increases upwards.
            pyplot.gca().invert_yaxis()

            if file is None:
                file = 'output'
            pyplot.savefig(file)

        def plot3D(self, label = None):
            """
            Plots 2-axis SirPlotsAlot in 3D, plotting array contents as 3rd dimension (up).

            Args:
                label: String used as the plot title. If it is not a string,
                    the instance attribute plot3D_label is used when present.

            Returns:
                Nothing

            Raises:
                ValueError: if this array is not 2D.

            """
            self._enforceXD(2)

            # Do not produce huge output: log a notice, not the array itself.
            lprint.debug("We're plotting you some goodies!")

            # Coordinate grid matching the array: columns are x, rows are y.
            x = np.linspace(0, self.shape[1], self.shape[1])
            y = np.linspace(0, self.shape[0], self.shape[0])
            x, y = np.meshgrid(x, y)

            fig = pyplot.figure()
            # Create the 3D axes first so that the title and axis labels
            # land on it rather than on an implicit 2D axes.
            ax = fig.add_subplot(111, projection='3d')
            if isinstance(label, str):      # if label is supplied, apply it.
                ax.set_title(label)
            elif hasattr(self, 'plot3D_label'):
                ax.set_title(self.plot3D_label)

            ax.plot_surface(x, y, self)

            ax.set_xlabel('x')
            ax.set_ylabel('y')

            pyplot.show()

        def plot3D_2(self, label = None):
            """
            Plots 2-axis SirPlotsAlot in 3D, plotting array contents as 3rd dimension (up),
            with contours projected onto each 2D cross-section of the 3D plot.

            Args:
                label: String used as the plot title. If it is not a string,
                    the instance attribute plot3D_2_label is used when present.

            Returns:
                Nothing

            Raises:
                ValueError: if this array is not 2D.

            """
            self._enforceXD(2)

            # Do not produce huge output: log a notice, not the array itself.
            lprint.debug("We're plotting you some goodies!")

            # Coordinate grid matching the array: columns are x, rows are y.
            x = np.linspace(0, self.shape[1], self.shape[1])
            y = np.linspace(0, self.shape[0], self.shape[0])
            x, y = np.meshgrid(x, y)

            fig = pyplot.figure()
            # Create the 3D axes first so that the title lands on it rather
            # than on an implicit 2D axes.
            ax = fig.add_subplot(111, projection='3d')
            if isinstance(label, str):
                ax.set_title(label)
            elif hasattr(self, 'plot3D_2_label'):
                ax.set_title(self.plot3D_2_label)

            ax.plot_surface(x, y, self, rstride=8, cstride=8, alpha=0.3)
            # Project contours onto the floor and the two back walls.
            ax.contour(x, y, self, zdir='z', offset=self.min())
            ax.contour(x, y, self, zdir='x', offset=0)
            ax.contour(x, y, self, zdir='y', offset=self.shape[0])

            ax.set_xlabel('x')
            ax.set_xlim(0, self.shape[1])
            ax.set_ylabel('y')
            ax.set_ylim(0, self.shape[0])
            ax.set_zlabel('z')
            ax.set_zlim(self.min(), self.max())

            pyplot.show()

        def aniPlot2D(self):
            """
            Generate successive 2D color plots using color for the data. Then play these
            plots in series, creating an animation. Requires 3D SirPlotsAlot.

            Each frame is one slice along the first axis of the array.

            Args:
                None

            Returns:
                Nothing

            Raises:
                ValueError: if this array is not 3D.

            """
            self._enforceXD(3)
            self.tplot = 0

            fig = pyplot.figure()

            # One frame per slice along axis 0, the axis self[t] indexes.
            frames = [(pyplot.imshow(self[t]),) for t in range(self.shape[0])]

            # Hold a reference until show() returns: an animation object that
            # is garbage-collected stops playing.
            anim = animation.ArtistAnimation(fig, frames, interval=50, repeat_delay=3000, blit=True)

            pyplot.show()
    
    
    
    def NewSirPlotsAlot(shape = (512, 512)):
        """
        Builds a SirPlotsAlot of the requested shape with every element set
        to zero; the shape may be left unspecified.

        Args:
            shape: A tuple of ints. The shape of the underlying ndarray.

        Returns:
            SirPlotsAlot

        Author / Date:
            Erasmus Alcarin   /   July 13, 2012

        """
        zeroed = SirPlotsAlot(shape)
        # The constructor leaves the memory uninitialized; overwrite it all.
        zeroed[:] = 0.0

        lprint.info("SirPlotsAlot has been populated with zeros.")

        return zeroed


    # Aliases for NewSirPlotsAlot
    splot = NewSirPlotsAlot
    
    
    def NewPsirPlotsAlot(dimensionality = 3, p = None):
        """
        Returns instance of SirPlotsAlot explicitly initialized to all zeros,
        sized from a Parameters object; arguments may be left unspecified.

        Args:
            dimensionality: An int, 2 or 3. Number of dimensions for array. (optional)
            p: A Parameters object exposing nx and ny. When omitted, the
                shared instance from Parameters.Instance() is used. (optional)

        Returns:
            SirPlotsAlot of shape (ny, nx) for 2D or (nx, ny, nx) for 3D.

        Raises:
            NotImplementedError: if dimensionality is neither 2 nor 3.
            ImportError: if neither pyrism.parameters nor parameters can be
                imported.

        Author / Date:
            Erasmus Alcarin   /   July 13, 2012

        """
        lprint.debug("Initializing SirPlotsAlot with initial dimensionality %s..." % dimensionality)

        # Prefer the packaged module; fall back to a plain import when the
        # code is run from inside the pyrism directory as loose files.
        try:
            import pyrism.parameters as par
        except ImportError:
            lprint.error("Use of pyrism as non-package detected. You must remain in the pyrism directory.")
            import parameters as par

        if p is None:
            p = par.Parameters.Instance()

        # extract size from parameters file, assuming size nx, ny
        if dimensionality == 2:
            shape = (p.ny, p.nx)
        elif dimensionality == 3:
            # NOTE(review): the third axis reuses p.nx -- confirm intended.
            shape = (p.nx, p.ny, p.nx)
        else:
            raise NotImplementedError("dimensionality must be 2 or 3")

        # Make and Get object; the constructor leaves memory uninitialized.
        s = SirPlotsAlot(shape)
        s[:] = 0.0

        lprint.info("SirPlotsAlot has been populated with zeros.")

        return s


    # Aliases for NewPsirPlotsAlot
    psplot = NewPsirPlotsAlot
    
    
    def returns_SirPlotsAlot(fn):
        """
        A decorator that changes the ndarray returned by fn into a
        SirPlotsAlot by means of the ndarray view function, so no data is
        copied. The wrapper keeps fn's name and docstring.

        Args:
            fn: A callable returning an ndarray.

        Returns:
            A callable with the same arguments as fn that returns SirPlotsAlot.

        """
        # Imported here so this block does not depend on the file's import list.
        import functools

        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            return fn(*args, **kwargs).view(SirPlotsAlot)
        return wrapped


    # Aliases for returns_SirPlotsAlot
    returns_splot = returns_SirPlotsAlot