Search code examples
Tags: python, oop

Inheriting from too many abstract classes?


I am trying to apply the DRY principle to some toy plotting classes as an intellectual exercise to improve my understanding of OOP (I am currently reading *Python Object-Oriented Programming*). Intuitively, though, an ever-deeper inheritance hierarchy seems likely to make the code base more complicated, especially since I have read that composition is generally favored over inheritance (see the Wikipedia article "Composition over inheritance"). Such a toy library could easily end up with far too many abstract classes, such as `AbstractMonthlyMultiPanelPlot` or `AbstractSeasonalPlot`, and so on for every plot type needed to accommodate different input data.

Is there a more Pythonic way of handling the problem below that I am missing? Am I violating a design principle that I have either misinterpreted or completely missed?

from abc import abstractmethod, ABC
from numpy import ndarray
from typing import List, Tuple
import matplotlib.pyplot as plt 

class AbstractPlot(ABC):
    """Root interface of the plot hierarchy: every plot must implement plot()."""

    @abstractmethod
    def plot(self):
        """Produce the plot; concrete subclasses define arguments and behavior."""
        raise NotImplementedError

class AbstractMonthlyPlot(AbstractPlot):
    """Template for plots made of one stacked panel per month.

    :meth:`plot` creates the figure layout and calls :meth:`_plot_for_month`
    once per panel; subclasses only implement :meth:`_plot_for_month`.
    """

    # The hook must carry the same name that ``plot`` calls (and that the
    # concrete subclasses implement); otherwise the subclasses stay abstract
    # and cannot be instantiated.
    @abstractmethod
    def _plot_for_month(self, ax, data):
        """Draw one month's ``data`` onto the axes ``ax``."""
        raise NotImplementedError

    @property
    def n_months(self):
        """Number of months in a year (one panel per month)."""
        return 12

    def plot(self, month_to_data: List[Tuple[ndarray, ...]]):
        """Draw every month into its own row of a new figure.

        Args:
            month_to_data: Per-month data indexed 0..n_months-1; entry ``i``
                is passed unchanged to ``_plot_for_month`` for panel ``i``.
                Must contain at least ``n_months`` entries.

        Returns:
            The newly created matplotlib figure, so callers can show, save or
            further customise it.
        """
        # n_months rows x 1 column; ``axs`` is indexed by month.
        fig, axs = plt.subplots(self.n_months, 1)
        for month in range(self.n_months):
            self._plot_for_month(ax=axs[month], data=month_to_data[month])
        return fig

class Contour(AbstractMonthlyPlot):
    """Monthly plot whose panels are drawn with ``Axes.contourf``."""

    def _plot_for_month(self, ax, data):
        """Fill-contour one month: the items of ``data`` become contourf's arguments."""
        draw_filled_contours = ax.contourf
        draw_filled_contours(*data)

class Linear(AbstractMonthlyPlot):
    """Monthly plot whose panels are drawn with ``Axes.plot``."""

    def _plot_for_month(self, ax, data):
        """Line-plot one month: the items of ``data`` become plot's arguments."""
        draw_lines = ax.plot
        draw_lines(*data)

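Solution

Favor composition over inheritance. Move the per-axes drawing into small `Plotter` strategy objects, and let a single `MonthlyPlot` class own the multi-panel layout. Plot types and layouts can then vary independently. Adding a new plot type means writing one new `Plotter`, rather than another abstract subclass for every layout.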

  • from abc import abstractmethod, ABC
    from numpy import ndarray
    from typing import List, Tuple
    import matplotlib.pyplot as plt
    
    class Plotter(ABC):
        """Strategy interface: draws one dataset onto one axes object."""

        @abstractmethod
        def plot(self, ax, data):
            """Draw ``data`` onto ``ax``; every concrete plotter must override this."""
            raise NotImplementedError()
    
    class ContourPlotter(Plotter):
        """Plotter that renders its data with ``Axes.contourf``."""

        def plot(self, ax, data):
            """Fill-contour ``data`` on ``ax``; the items of ``data`` become contourf's arguments."""
            draw_filled_contours = ax.contourf
            draw_filled_contours(*data)
    
    class LinearPlotter(Plotter):
        """Plotter that renders its data with ``Axes.plot``."""

        def plot(self, ax, data):
            """Line-plot ``data`` on ``ax``; the items of ``data`` become plot's arguments."""
            draw_lines = ax.plot
            draw_lines(*data)
    
    class MonthlyPlot:
        """Lay out one stacked panel per month and delegate the drawing.

        The panel layout lives here; how each panel is drawn comes from the
        composed ``plotter``, so a new plot type needs a new plotter, not a
        new subclass of this class.
        """

        def __init__(self, plotter: Plotter):
            """Store the plotter used to draw every monthly panel."""
            self.plotter = plotter

        @property
        def n_months(self):
            """Number of months in a year (one panel per month)."""
            return 12

        def plot(self, month_to_data: List[Tuple[ndarray, ...]]):
            """Draw each month's data into its own row of a new figure.

            Args:
                month_to_data: Per-month data indexed 0..n_months-1; entry ``i``
                    is passed unchanged to ``self.plotter.plot`` for panel ``i``.
                    Must contain at least ``n_months`` entries.

            Returns:
                The newly created matplotlib figure, so callers can show, save
                or further customise it.
            """
            # n_months rows x 1 column; ``axs`` is indexed by month.
            fig, axs = plt.subplots(self.n_months, 1)
            for month in range(self.n_months):
                self.plotter.plot(ax=axs[month], data=month_to_data[month])
            return fig
    
    # Usage
    # NOTE(review): `month_to_data` is not defined in this snippet and must be
    # provided by the caller: a list with one entry per month (at least 12,
    # since MonthlyPlot.plot indexes months 0-11), where each entry is
    # unpacked into the plotter's matplotlib call (contourf / plot).
    contour_plotter = ContourPlotter()
    monthly_contour_plot = MonthlyPlot(contour_plotter)
    monthly_contour_plot.plot(month_to_data)
    
    # Same monthly layout, different drawing strategy: only the plotter changes.
    linear_plotter = LinearPlotter()
    monthly_linear_plot = MonthlyPlot(linear_plotter)
    monthly_linear_plot.plot(month_to_data)