
Child classes with different signatures: how to resolve this reasonably without breaking the code?


I am implementing machine learning algorithms from scratch in Python. I have a base class called BaseEstimator with the following structure:

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, TypeVar

import numpy as np
import torch  # needed for the torch.Tensor constraint on T

T = TypeVar("T", np.ndarray, torch.Tensor)


class BaseEstimator(ABC):
    """Base Abstract Class for Estimators."""

    @abstractmethod
    def fit(self, X: T, y: Optional[T] = None) -> BaseEstimator:
        """Fit the model according to the given training data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training vector, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like, shape (n_samples,) or (n_samples, n_outputs), optional
            Target relative to X for classification or regression;
            None for unsupervised learning.

        Returns
        -------
        self : object
            Returns self.
        """

    @abstractmethod
    def predict(self, X: T) -> T:
        """Predict class labels for samples in X.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Samples.

        Returns
        -------
        C : array, shape (n_samples,)
            Predicted class label per sample.
        """


class KMeans(BaseEstimator):
    # Unsupervised: fit() omits y, so this override is narrower than BaseEstimator.fit.
    def fit(self, X: T) -> BaseEstimator:
        ...

    def predict(self, X: T) -> T:
        ...


class LogisticRegression(BaseEstimator):
    def fit(self, X: T, y: Optional[T] = None) -> BaseEstimator:
        ...

    def predict(self, X: T) -> T:
        ...

When I implemented the base class I did not plan properly: some algorithms, such as KMeans, are unsupervised and do not need y in fit at all. A quick fix I thought of is to type-hint y as Optional so that it can be None. Is that okay? In that case, KMeans' fit method would also have to include y: Optional[T] = None, a parameter that would never be used.
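The quick fix can be sketched as follows. This is a minimal illustration, assuming plain Python lists instead of NumPy/PyTorch arrays so it runs without extra dependencies; `MeanSplitClusterer`, `ThresholdClassifier` and `fit_predict` are hypothetical names, not part of the original code. The unsupervised estimator keeps `y: Optional[...] = None` and ignores it, which mirrors the convention scikit-learn follows (its `KMeans.fit(X, y=None, ...)` documents `y` as ignored and present only for API consistency).

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional, Sequence

Data = Sequence[float]  # hypothetical stand-in for np.ndarray / torch.Tensor
Labels = Sequence[int]


class BaseEstimator(ABC):
    """Every estimator accepts an optional target in fit()."""

    @abstractmethod
    def fit(self, X: Data, y: Optional[Labels] = None) -> BaseEstimator:
        ...

    @abstractmethod
    def predict(self, X: Data) -> List[int]:
        ...


class MeanSplitClusterer(BaseEstimator):
    """Hypothetical unsupervised estimator: splits 1-D data at its mean."""

    def fit(self, X: Data, y: Optional[Labels] = None) -> MeanSplitClusterer:
        # y is accepted only to keep the base signature; it is ignored.
        self.threshold_ = sum(X) / len(X)
        return self

    def predict(self, X: Data) -> List[int]:
        return [int(x > self.threshold_) for x in X]


class ThresholdClassifier(BaseEstimator):
    """Hypothetical supervised estimator: thresholds halfway between the
    two class means (assumes class 1 lies above class 0)."""

    def fit(self, X: Data, y: Optional[Labels] = None) -> ThresholdClassifier:
        # The signature cannot force callers to pass y, so check at runtime.
        if y is None:
            raise ValueError("ThresholdClassifier.fit requires y")
        ys = list(y)
        mean0 = sum(x for x, t in zip(X, ys) if t == 0) / ys.count(0)
        mean1 = sum(x for x, t in zip(X, ys) if t == 1) / ys.count(1)
        self.threshold_ = (mean0 + mean1) / 2
        return self

    def predict(self, X: Data) -> List[int]:
        return [int(x > self.threshold_) for x in X]


def fit_predict(est: BaseEstimator, X: Data, y: Optional[Labels] = None) -> List[int]:
    # One generic call site works for both kinds of estimator.
    return est.fit(X, y).predict(X)


X = [1.0, 2.0, 10.0, 11.0]
print(fit_predict(MeanSplitClusterer(), X))                 # [0, 0, 1, 1]
print(fit_predict(ThresholdClassifier(), X, [0, 0, 1, 1]))  # [0, 0, 1, 1]
```

The trade-off: generic code (pipelines, cross-validation helpers) can call `fit(X, y)` on any estimator, but a supervised estimator can only detect a missing `y` at runtime.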


Solution

  • You can derive two intermediate classes from your base class: ClassificationEstimator and ClusteringEstimator. Move the fit method out of the base class into both of these classes, and drop the y parameter from fit in ClusteringEstimator. Then inherit each concrete estimator from the appropriate one of these two classes.

    This way your classifier classes do not change, but the clustering classes do, and every place that calls their fit must be updated. Alternatively, you could add an overload of fit in the ClusteringEstimator base class that accepts y as an optional, unused parameter. However, keeping an unused y parameter in fit is not ideal, because it pollutes the library interface.
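The class hierarchy suggested in this answer might look like the sketch below. It assumes plain Python lists in place of NumPy/PyTorch arrays, and the concrete estimators `MeanSplitClusterer` and `ThresholdClassifier` are hypothetical toy models used only to show the signatures.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Sequence

Data = Sequence[float]  # hypothetical stand-in for np.ndarray / torch.Tensor
Labels = Sequence[int]


class BaseEstimator(ABC):
    """Shared interface: only predict(); fit() lives in the subclasses."""

    @abstractmethod
    def predict(self, X: Data) -> List[int]:
        ...


class ClassificationEstimator(BaseEstimator):
    @abstractmethod
    def fit(self, X: Data, y: Labels) -> ClassificationEstimator:
        ...


class ClusteringEstimator(BaseEstimator):
    @abstractmethod
    def fit(self, X: Data) -> ClusteringEstimator:
        ...


class MeanSplitClusterer(ClusteringEstimator):
    """Hypothetical clusterer: splits 1-D data at its mean."""

    def fit(self, X: Data) -> MeanSplitClusterer:
        self.threshold_ = sum(X) / len(X)
        return self

    def predict(self, X: Data) -> List[int]:
        return [int(x > self.threshold_) for x in X]


class ThresholdClassifier(ClassificationEstimator):
    """Hypothetical classifier: thresholds halfway between the two class
    means (assumes class 1 lies above class 0)."""

    def fit(self, X: Data, y: Labels) -> ThresholdClassifier:
        ys = list(y)
        mean0 = sum(x for x, t in zip(X, ys) if t == 0) / ys.count(0)
        mean1 = sum(x for x, t in zip(X, ys) if t == 1) / ys.count(1)
        self.threshold_ = (mean0 + mean1) / 2
        return self

    def predict(self, X: Data) -> List[int]:
        return [int(x > self.threshold_) for x in X]


X = [1.0, 2.0, 10.0, 11.0]
print(MeanSplitClusterer().fit(X).predict(X))                 # [0, 0, 1, 1]
print(ThresholdClassifier().fit(X, [0, 0, 1, 1]).predict(X))  # [0, 0, 1, 1]
```

Because BaseEstimator no longer declares fit, neither override conflicts with a parent signature, and passing y to a clusterer (or omitting it for a classifier) fails immediately with a TypeError. The cost is that code typed only against BaseEstimator can no longer call fit.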
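One possible reading of the overload alternative uses `typing.overload`, sketched below. Python keeps only one runtime implementation, so the two `@overload` signatures exist for type checkers while the real `fit` accepts `y` and discards it. The `_fit` hook and `MeanSplitClusterer` are hypothetical names introduced here so that concrete clusterers never see `y`; lists stand in for NumPy/PyTorch arrays.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional, Sequence, overload

Data = Sequence[float]  # hypothetical stand-in for np.ndarray / torch.Tensor
Labels = Sequence[int]


class BaseEstimator(ABC):
    @abstractmethod
    def predict(self, X: Data) -> List[int]:
        ...


class ClusteringEstimator(BaseEstimator):
    # Typing-only signatures: fit(X), or fit(X, y) where y is ignored.
    @overload
    def fit(self, X: Data) -> ClusteringEstimator: ...
    @overload
    def fit(self, X: Data, y: Optional[Labels]) -> ClusteringEstimator: ...

    def fit(self, X: Data, y: Optional[Labels] = None) -> ClusteringEstimator:
        # The single runtime implementation: drop y and delegate to _fit().
        return self._fit(X)

    @abstractmethod
    def _fit(self, X: Data) -> ClusteringEstimator:
        """Hypothetical hook that subclasses implement without a y parameter."""
        ...


class MeanSplitClusterer(ClusteringEstimator):
    """Hypothetical clusterer: splits 1-D data at its mean."""

    def _fit(self, X: Data) -> MeanSplitClusterer:
        self.threshold_ = sum(X) / len(X)
        return self

    def predict(self, X: Data) -> List[int]:
        return [int(x > self.threshold_) for x in X]


X = [1.0, 2.0, 10.0, 11.0]
print(MeanSplitClusterer().fit(X).predict(X))                # [0, 0, 1, 1]
print(MeanSplitClusterer().fit(X, [1, 1, 0, 0]).predict(X))  # y ignored: [0, 0, 1, 1]
```

As the answer notes, the unused y still shows up in the public signature; the overloads only document that both call forms are accepted.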