How to check whether an sklearn estimator is a scaler?


I'm writing a function that needs to determine whether an object passed to it is an imputer (can check with isinstance(obj, _BaseImputer)), a scaler, or something else.

While all imputers have a common base class that identifies them as imputers, scalers do not. I found that all scalers in sklearn.preprocessing._data inherit from (OneToOneFeatureMixin, TransformerMixin, BaseEstimator), so I could check whether an object is an instance of all of them. However, that could generate false positives (I'm not sure which other objects inherit from the same base classes). It doesn't feel very clean or pythonic either.

I was also thinking of checking whether the object has the .inverse_transform() method. However, scalers are not the only objects that have it: a SimpleImputer (and maybe other objects) has it as well.

How can I easily check if my object is a scaler?


Solution

  • Unfortunately, the cleanest way to do this is to check each scaler type individually; any other check may also let non-scaler objects through. Nevertheless, I'll offer some "hack-jobs" too.

    The most failsafe solution is to import your scalers and then check whether your object is an instance of any of them.

    from sklearn.preprocessing import MinMaxScaler, RobustScaler  # ... plus any other scalers your code base uses
    
    # isinstance accepts a tuple of classes, so no explicit loop is needed
    SCALER_TYPES = (MinMaxScaler, RobustScaler)  # Extend the tuple if needed
    
    if isinstance(YourObject, SCALER_TYPES):
        # Do something
        pass
    else:
        # Do something else
        pass
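
    As a minimal sketch, the explicit check can be wrapped in a small helper. The name is_scaler and the choice of these four classes are assumptions made for illustration; adjust the tuple to whatever your code base treats as a scaler.

```python
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import (
    MaxAbsScaler,
    MinMaxScaler,
    RobustScaler,
    StandardScaler,
)

# Assumption: these four classes are what "scaler" means in this code base.
SCALER_TYPES = (MaxAbsScaler, MinMaxScaler, RobustScaler, StandardScaler)


def is_scaler(obj):
    """Return True if obj is an instance of one of the listed scaler classes."""
    # isinstance accepts a tuple of classes and also matches their subclasses.
    return isinstance(obj, SCALER_TYPES)


print(is_scaler(StandardScaler()))  # an instance of a listed scaler class
print(is_scaler(SimpleImputer()))   # an imputer, not in the tuple
```

    Because isinstance also matches subclasses, a custom scaler that derives from one of the listed classes would pass this check without being added to the tuple.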
    

    Now, if you want something that catches them all without listing every scaler you use in your code, you could rely on implementation details of the scaler objects, such as how the classes are named. Such details are not part of the public API and are subject to change without notice, so nothing guarantees that your code will keep working if you update sklearn to a new version. For example, you could check whether the string representation (__repr__) of the object contains Scaler. This is how you can do it:

    if 'Scaler' in str(YourObject):
        # Do something
        pass
    else:
        # Do something else
        pass
    

    or

    if 'Scaler' in repr(YourObject):
        # Do something
        pass
    else:
        # Do something else
        pass
    

    This will let through anything that has Scaler in its string representation, though, so you are definitely better off being explicit and defining your list of scalers.
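
    To illustrate that caveat, here is a small sketch of one such false positive. The helper name looks_like_scaler and the example pipeline are hypothetical, introduced only for this demonstration; the check itself is the string test shown above.

```python
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def looks_like_scaler(obj):
    """String-based check: True if 'Scaler' appears in str(obj)."""
    return "Scaler" in str(obj)


# Hypothetical pipeline, used only to illustrate the false positive.
pipeline = Pipeline([("impute", SimpleImputer()), ("scale", StandardScaler())])

print(looks_like_scaler(StandardScaler()))  # a real scaler
print(looks_like_scaler(SimpleImputer()))   # an imputer
print(looks_like_scaler(pipeline))          # not a scaler, but its repr lists its steps
```

    The pipeline passes the check because its string representation includes the names of its steps, even though the pipeline itself is not a scaler.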