I am trying to apply the DRY principle to some toy plotting classes as an intellectual exercise to improve my understanding of OOP (I am currently reading *Python Object-Oriented Programming*). Intuitively, though, an ever-deeper inheritance hierarchy seems likely to make the code base more complicated, especially since I have read that composition is generally favored over inheritance (see the "Composition over inheritance" Wikipedia article). Such a toy library could end up with far too many abstract classes, such as `AbstractMonthlyMultiPanelPlot`
or `AbstractSeasonalPlot`,
and so on, one for each plot type needed to accommodate different input data.
Is there a more Pythonic way of handling the problem below that I am missing? Am I violating some design principle that I have either misinterpreted or missed entirely?
from abc import abstractmethod, ABC
from numpy import ndarray
from typing import List, Tuple
import matplotlib.pyplot as plt
class AbstractPlot(ABC):
    """Root interface for every plot type.

    Concrete subclasses must implement ``plot``; subclasses may give it extra
    parameters (``AbstractMonthlyPlot.plot`` takes the per-month data).
    """

    @abstractmethod
    def plot(self):
        """Draw the plot. Must be overridden by concrete subclasses."""
        raise NotImplementedError
class AbstractMonthlyPlot(AbstractPlot):
    """Base class for a figure with one vertically stacked panel per month.

    Subclasses implement ``_plot_for_month`` to draw a single month's data on
    one Axes; ``plot`` creates the figure and loops over the months.
    """

    @abstractmethod
    def _plot_for_month(self, ax, data):
        """Draw one month's ``data`` on the Axes ``ax``.

        The abstract name and signature must match what ``plot`` calls;
        otherwise subclasses implementing ``_plot_for_month`` would still be
        abstract and could not be instantiated.
        """
        raise NotImplementedError

    @property
    def n_months(self):
        """Number of monthly panels (months in a year)."""
        return 12

    def plot(self, month_to_data: List[Tuple[ndarray, ...]]):
        """Create an ``n_months`` x 1 grid of axes and draw each month in its row.

        Args:
            month_to_data: Sequence indexed by 0-based month; entry ``i`` is
                passed unchanged as ``data`` to ``_plot_for_month`` for row ``i``.

        Returns:
            The figure created by ``plt.subplots``.

        Raises:
            IndexError: if ``month_to_data`` (e.g. a list) has fewer than
                ``n_months`` entries.
        """
        # squeeze=False always yields a 2-D array of Axes, so iterating the
        # first column works even if a subclass overrides n_months to 1.
        fig, axs = plt.subplots(self.n_months, 1, squeeze=False)
        for month, ax in enumerate(axs[:, 0]):
            self._plot_for_month(ax=ax, data=month_to_data[month])
        return fig
class Contour(AbstractMonthlyPlot):
    """Monthly multi-panel plot drawing filled contours in each panel."""

    def _plot_for_month(self, ax, data):
        """Forward the items of ``data`` as positional args to ``ax.contourf``."""
        draw = ax.contourf
        draw(*data)
class Linear(AbstractMonthlyPlot):
    """Monthly multi-panel plot drawing lines (``ax.plot``) in each panel."""

    def _plot_for_month(self, ax, data):
        """Forward the items of ``data`` as positional args to ``ax.plot``."""
        draw = ax.plot
        draw(*data)
from abc import abstractmethod, ABC
from numpy import ndarray
from typing import List, Tuple
import matplotlib.pyplot as plt
class Plotter(ABC):
    """Strategy interface: draws one dataset onto one Axes.

    ``MonthlyPlot`` holds a ``Plotter`` and calls ``plot`` once per panel.
    """

    @abstractmethod
    def plot(self, ax, data):
        """Render ``data`` on ``ax``; concrete plotters must override this."""
        raise NotImplementedError()
class ContourPlotter(Plotter):
    """Plotter that draws filled contours."""

    def plot(self, ax, data):
        """Forward the items of ``data`` as positional args to ``ax.contourf``."""
        draw = ax.contourf
        draw(*data)
class LinearPlotter(Plotter):
    """Plotter that draws lines via ``ax.plot``."""

    def plot(self, ax, data):
        """Forward the items of ``data`` as positional args to ``ax.plot``."""
        draw = ax.plot
        draw(*data)
class MonthlyPlot:
    """Figure with one vertically stacked panel per month.

    How each panel is drawn is delegated to the injected ``Plotter``
    (composition), so a new plot kind needs a new ``Plotter`` rather than a
    new subclass of this class.
    """

    def __init__(self, plotter: Plotter):
        """Store the strategy used to draw each month's panel.

        Args:
            plotter: Object whose ``plot(ax, data)`` draws a single month.
        """
        self.plotter = plotter

    @property
    def n_months(self):
        """Number of monthly panels (months in a year)."""
        return 12

    def plot(self, month_to_data: List[Tuple[ndarray, ...]]):
        """Create an ``n_months`` x 1 grid of axes and draw each month in its row.

        Args:
            month_to_data: Sequence indexed by 0-based month; entry ``i`` is
                passed unchanged as ``data`` to ``self.plotter.plot`` for row ``i``.

        Returns:
            The figure created by ``plt.subplots``.

        Raises:
            IndexError: if ``month_to_data`` (e.g. a list) has fewer than
                ``n_months`` entries.
        """
        # squeeze=False always yields a 2-D array of Axes, so iterating the
        # first column works even if n_months is overridden to 1.
        fig, axs = plt.subplots(self.n_months, 1, squeeze=False)
        for month, ax in enumerate(axs[:, 0]):
            self.plotter.plot(ax=ax, data=month_to_data[month])
        return fig
# Usage
import numpy as np

# Example input: one entry per month (12 in total). ContourPlotter forwards
# (X, Y, Z) grids to contourf, while LinearPlotter forwards (x, y) vectors to
# ax.plot, so each plot type gets data shaped for it.
_x = np.linspace(0.0, 2.0 * np.pi, 50)
_X, _Y = np.meshgrid(_x, _x)
month_to_contour_data = [(_X, _Y, np.sin(_X + month) * np.cos(_Y)) for month in range(12)]
month_to_line_data = [(_x, np.sin(_x + month)) for month in range(12)]

contour_plotter = ContourPlotter()
monthly_contour_plot = MonthlyPlot(contour_plotter)
monthly_contour_plot.plot(month_to_contour_data)

linear_plotter = LinearPlotter()
monthly_linear_plot = MonthlyPlot(linear_plotter)
monthly_linear_plot.plot(month_to_line_data)

# Display the figures created above.
plt.show()