
Python sub-class that is inherited by multiple classes


Currently I have a class definition that looks something like this:

import matplotlib.pyplot as plt

class cls_A:

    def __init__(self, a1, a2, a3):
        # store x and y on the instance so the plot functions can reach them
        self.x = a1 + a2
        self.y = a2 + a3
        self.__plot__(self)
    # some attributes
    # some methods

    def __plot__(self, parent):
        class plot_fns:
            # plot functions
            def plot1(self):
                plt.plot(parent.x, parent.y)
        self.plot = plot_fns()

A = cls_A(a1, a2, a3)

I based this on the answer to "How to access outer class from an inner class?". It lets me use the sub-object `plot` to conveniently plot the parameters of the object A, for example:

A.plot.plot1()
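The inner-class trick works because `plot1` is a closure over `parent`: the lookup `parent.x` happens when `plot1` is called, not when `plot_fns` is defined. That is also why `x` and `y` must be stored on the instance (`self.x`) rather than as local variables in `__init__`. A minimal sketch of this behaviour, with `print` standing in for `plt.plot` and a hypothetical class name `Demo`:

```python
class Demo:
    def __init__(self, a1, a2, a3):
        self.__plot__(self)   # plot_fns is created before x and y exist...
        self.x = a1 + a2      # ...but x and y are stored as attributes,
        self.y = a2 + a3      # so plot1 can still find them when called

    def __plot__(self, parent):
        class plot_fns:
            def plot1(self):
                # 'parent' is captured from the enclosing method;
                # parent.x and parent.y are looked up only when plot1 runs
                print("plotting", parent.x, parent.y)
                return parent.x, parent.y
        self.plot = plot_fns()


d = Demo(1, 2, 3)
d.plot.plot1()   # prints: plotting 3 5
d.x = 10
d.plot.plot1()   # prints: plotting 10 5 (the new value is picked up)
```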

Now I have a similar class, cls_B, with an instance B. cls_A and cls_B do not share all of their attributes, but the attributes used for plotting (x and y) are the same, so I would like B to have the same plot sub-object. At the moment I simply repeat the definition of the method __plot__ in cls_B:

class cls_B:

    def __init__(self, b1, b2, b3):
        # store x and y on the instance so the plot functions can reach them
        self.x = b1 + b2
        self.y = b2 + b3
        self.__plot__(self)
    # some attributes
    # some methods

    def __plot__(self, parent):
        class plot_fns:
            # plot functions
            def plot1(self):
                plt.plot(parent.x, parent.y)
        self.plot = plot_fns()

B = cls_B(b1, b2, b3)
B.plot.plot1()

Is there a better way to do this?


Solution

  • To stay true to what you are doing without repeating yourself, you can move __plot__ into a common base class, SubPlot, which defines the inner plot class once, and have both classes inherit from it:

    class SubPlot:
        def __init__(self):
            self.__plot__(self)
    
        def __plot__(self, parent):
            class plot_fns:
                # plot functions
                def plot1(self):
                    print("plotting", parent.x, parent.y)
            self.plot = plot_fns()
    
    
    class PlotA(SubPlot):
        def __init__(self, x, y):
            super().__init__()
            self.x = x
            self.y = y
    
    class PlotB(SubPlot):
        def __init__(self, x, y):
            super().__init__()
            self.x = x
            self.y = y
    
    
    PlotA(1, 2).plot.plot1()
    PlotB(3, 4).plot.plot1()
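A variation on the same idea, sketched here under the assumption that the plot helpers only need the owner's x and y: define the helper class once at module level and expose it through a mixin property, so no class is created inside a method and subclasses do not need to call `super().__init__()`. It also avoids naming a method `__plot__`, since names of the form `__*__` are reserved for Python's own special methods. `PlotFns` and `PlotMixin` are hypothetical names, and `print` stands in for `plt.plot`:

```python
class PlotFns:
    """Plot helpers bound to one owner object (hypothetical helper class)."""

    def __init__(self, parent):
        self.parent = parent  # keep a reference to the owning object

    def plot1(self):
        # print() stands in for plt.plot(self.parent.x, self.parent.y)
        print("plotting", self.parent.x, self.parent.y)
        return self.parent.x, self.parent.y


class PlotMixin:
    """Gives any subclass a `plot` sub-object; assumes it defines x and y."""

    @property
    def plot(self):
        # Build a lightweight helper bound to this instance on each access
        return PlotFns(self)


class cls_A(PlotMixin):
    def __init__(self, a1, a2, a3):
        self.x = a1 + a2
        self.y = a2 + a3


class cls_B(PlotMixin):
    def __init__(self, b1, b2, b3):
        self.x = b1 + b2
        self.y = b2 + b3


cls_A(1, 2, 3).plot.plot1()  # prints: plotting 3 5
cls_B(4, 5, 6).plot.plot1()  # prints: plotting 9 11
```

If you prefer a single helper object per instance instead of a new one on every access, `functools.cached_property` (Python 3.8+) can be used in place of `property`.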