I am implementing machine learning algorithms from scratch using Python. I have a base class called `BaseEstimator`
with the following structure:
```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, TypeVar

import numpy as np
import torch

# Estimators accept either NumPy arrays or PyTorch tensors.
T = TypeVar("T", np.ndarray, torch.Tensor)


class BaseEstimator(ABC):
    """Base Abstract Class for Estimators."""

    @abstractmethod
    def fit(self, X: T, y: Optional[T] = None) -> BaseEstimator:
        """Fit the model according to the given training data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training vector, where n_samples is the number of samples and
            n_features is the number of features.
        y : array-like, shape (n_samples,) or (n_samples, n_outputs), optional
            Target relative to X for classification or regression;
            None for unsupervised learning.

        Returns
        -------
        self : object
            Returns self.
        """

    @abstractmethod
    def predict(self, X: T) -> T:
        """Predict class labels for samples in X.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Samples.

        Returns
        -------
        C : array, shape (n_samples,)
            Predicted class label per sample.
        """


class KMeans(BaseEstimator):
    def fit(self, X: T) -> BaseEstimator:  # unsupervised: no y
        ...

    def predict(self, X: T) -> T:
        ...


class LogisticRegression(BaseEstimator):
    def fit(self, X: T, y: Optional[T] = None) -> BaseEstimator:
        ...

    def predict(self, X: T) -> T:
        ...
```
When I implemented the base class I did not plan properly: some algorithms, such as `KMeans`, are unsupervised and hence do not need `y` at all in `fit`. A quick fix I thought of is to type-hint `y` as `Optional` so that it can be `None`. Is that okay? In that case, the `fit` method of `KMeans` will also have to include `y: Optional[T] = None`, which will never be used.
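A minimal sketch of that quick fix, assuming NumPy inputs only (the `torch.Tensor` branch of `T` is dropped to keep it self-contained): `KMeans.fit` keeps the `y: Optional[...] = None` parameter purely to match `BaseEstimator.fit` and ignores it. The `n_clusters`, `n_iter` and `seed` constructor parameters and the plain Lloyd's-iteration body are illustrative assumptions, not part of the original code.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

import numpy as np


class BaseEstimator(ABC):
    """Base class whose fit() accepts an optional target (the quick fix)."""

    @abstractmethod
    def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> BaseEstimator:
        ...

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        ...


class KMeans(BaseEstimator):
    # Hypothetical constructor parameters, chosen for illustration only.
    def __init__(self, n_clusters: int = 2, n_iter: int = 100, seed: int = 0) -> None:
        self.n_clusters = n_clusters
        self.n_iter = n_iter
        self.seed = seed

    def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> KMeans:
        # y exists only so the signature matches BaseEstimator.fit; it is ignored.
        rng = np.random.default_rng(self.seed)
        idx = rng.choice(len(X), size=self.n_clusters, replace=False)
        self.centers_ = X[idx].astype(float)
        for _ in range(self.n_iter):
            labels = self.predict(X)
            # Move each center to the mean of its samples; keep it if the cluster is empty.
            new_centers = np.array([
                X[labels == k].mean(axis=0) if np.any(labels == k) else self.centers_[k]
                for k in range(self.n_clusters)
            ])
            if np.allclose(new_centers, self.centers_):
                break
            self.centers_ = new_centers
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Assign each sample to its nearest center (squared Euclidean distance).
        d = ((X[:, None, :] - self.centers_[None, :, :]) ** 2).sum(axis=2)
        return d.argmin(axis=1)


X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
labels = KMeans(n_clusters=2).fit(X).predict(X)
```

Keeping the unused parameter makes the override signature-compatible, so a type checker such as mypy accepts it. scikit-learn follows the same convention: its `KMeans.fit` documents `y` as ignored and present only for API consistency.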
You can derive two intermediate classes from your base class: `ClassificationEstimator` and `ClusteringEstimator`. Then move the `fit` method out of the base class into both of them, and drop the `y` parameter from `ClusteringEstimator.fit`. Finally, inherit each concrete estimator from the matching intermediate class (`LogisticRegression` from `ClassificationEstimator`, `KMeans` from `ClusteringEstimator`).
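The hierarchy described above might look roughly like this, as a sketch assuming NumPy inputs only. `predict` stays in `BaseEstimator`, while each intermediate class declares its own `fit`. Requiring `y` in `ClassificationEstimator.fit` is an assumption of this sketch. `NearestCentroidClassifier` is a hypothetical toy standing in for `LogisticRegression`, and the `KMeans` body is a minimal Lloyd's iteration added only so the example runs.

```python
from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np


class BaseEstimator(ABC):
    """Shared interface: every estimator can predict; fit() lives in subclasses."""

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        ...


class ClassificationEstimator(BaseEstimator):
    """Supervised estimators: fit() takes the targets y."""

    @abstractmethod
    def fit(self, X: np.ndarray, y: np.ndarray) -> ClassificationEstimator:
        ...


class ClusteringEstimator(BaseEstimator):
    """Unsupervised estimators: fit() takes only X."""

    @abstractmethod
    def fit(self, X: np.ndarray) -> ClusteringEstimator:
        ...


class NearestCentroidClassifier(ClassificationEstimator):
    """Hypothetical toy classifier standing in for LogisticRegression."""

    def fit(self, X: np.ndarray, y: np.ndarray) -> NearestCentroidClassifier:
        self.classes_ = np.unique(y)
        self.centroids_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Label each sample with the class of its nearest centroid.
        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[d.argmin(axis=1)]


class KMeans(ClusteringEstimator):
    """Minimal Lloyd's-iteration KMeans; fit() has no y parameter at all."""

    def __init__(self, n_clusters: int = 2, n_iter: int = 20, seed: int = 0) -> None:
        self.n_clusters, self.n_iter, self.seed = n_clusters, n_iter, seed

    def fit(self, X: np.ndarray) -> KMeans:
        rng = np.random.default_rng(self.seed)
        self.centers_ = X[rng.choice(len(X), self.n_clusters, replace=False)].astype(float)
        for _ in range(self.n_iter):
            labels = self.predict(X)
            # Recompute centers; an empty cluster keeps its previous center.
            self.centers_ = np.array([
                X[labels == k].mean(axis=0) if np.any(labels == k) else self.centers_[k]
                for k in range(self.n_clusters)
            ])
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        d = ((X[:, None, :] - self.centers_[None, :, :]) ** 2).sum(axis=2)
        return d.argmin(axis=1)
```

Code that only needs predictions can still accept any `BaseEstimator`, while code that trains models depends on the narrower class whose `fit` signature it actually uses.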
This way your classifier classes do not change, but the clustering classes do, and their call sites must be updated accordingly. Alternatively, you could add an overload of `fit` in the `ClusteringEstimator` base class that accepts `y` as an optional, unused parameter. Keeping an unused `y` parameter in `fit` is not ideal, though, because it pollutes the library's interface.
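In Python, overloading a method exists only for the type checker, through `typing.overload`; at runtime there is a single `fit`. A minimal sketch of that alternative, assuming NumPy inputs: the `@overload` stubs advertise `fit(X)` and `fit(X, None)`, so a type checker rejects passing a real target, while the runtime implementation simply ignores `y`. The `_fit` hook and the `OneCluster` toy subclass are hypothetical names introduced for illustration.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, overload

import numpy as np


class ClusteringEstimator(ABC):
    """fit() tolerates an optional y only for call-site compatibility."""

    @overload
    def fit(self, X: np.ndarray) -> ClusteringEstimator: ...

    @overload
    def fit(self, X: np.ndarray, y: None) -> ClusteringEstimator: ...

    def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> ClusteringEstimator:
        # y is never used; the overloads above only let callers pass None.
        return self._fit(X)

    @abstractmethod
    def _fit(self, X: np.ndarray) -> ClusteringEstimator:
        """Hypothetical hook that subclasses implement instead of fit()."""

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        ...


class OneCluster(ClusteringEstimator):
    """Hypothetical toy clusterer: every sample goes to cluster 0."""

    def _fit(self, X: np.ndarray) -> OneCluster:
        self.center_ = X.mean(axis=0)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        return np.zeros(len(X), dtype=int)
```

Delegating to `_fit` means each clustering subclass writes the `y`-free signature once, and the ignored parameter stays confined to the base class instead of spreading through every estimator.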