machine-learning scikit-learn classification mean imputation

scikit-learn impute mean of feature within groups of nominal value in another feature


I want to impute the mean of a feature, but calculate that mean only from the other examples that share the same category (nominal value) in another column. Is this possible with scikit-learn's Imputer class? That would make it easier to add into a pipeline.

For example:

Using the Titanic dataset from Kaggle: source

How would I go about imputing the mean fare per pclass? The thinking behind it is that ticket prices differ greatly between passenger classes.
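To make the goal concrete, here is a minimal sketch of within-class mean imputation on a toy frame. It uses pandas rather than scikit-learn, so it is not a pipeline step. The column names `Pclass` and `Fare` are assumed to match the Kaggle CSV, and the values are made up for illustration.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the Titanic data; the values are invented for illustration.
df = pd.DataFrame({
    "Pclass": [1, 1, 1, 3, 3, 3],
    "Fare": [80.0, 100.0, np.nan, 8.0, np.nan, 10.0],
})

# Mean fare of each passenger's own class, aligned row by row with df.
# Missing fares are skipped when the mean is computed.
class_mean = df.groupby("Pclass")["Fare"].transform("mean")

# Fill only the missing fares; observed fares are left untouched.
df["Fare"] = df["Fare"].fillna(class_mean)
print(df)
```

Each missing fare is replaced by the mean of the observed fares in the same class, and the row order is unchanged.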

Update: After discussion with some people, the phrase I should have used was "imputing the mean within class".

I've looked into Vivek's comment below and will construct a generic pipeline function when I get time to do what I want :) I have a good idea of how to do it and will post as an answer when it's finished.


Solution

  • Below is a fairly simple approach to my question that only handles means. A more robust implementation would probably build on scikit-learn's Imputer class (replaced by sklearn.impute.SimpleImputer in later releases), which would also allow mode, median, etc. and would cope better with sparse/dense matrices.

    This is based on Vivek Kumar's comment on the original question, which suggested splitting the data into stacks, imputing each stack, and then re-assembling them.

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin
    
    class WithinClassMeanImputer(BaseEstimator, TransformerMixin):
        def __init__(self, replace_col_index, class_col_index=None, missing_values=np.nan):
            self.missing_values = missing_values
            self.replace_col_index = replace_col_index
            self.y = None
            self.class_col_index = class_col_index
    
        def fit(self, X, y=None):
            # Remember the labels; they define the groups when class_col_index is None
            self.y = y
            return self
    
        def _is_missing(self, column):
            # NaN never compares equal to itself, so it needs np.isnan rather than ==
            if isinstance(self.missing_values, float) and np.isnan(self.missing_values):
                return np.isnan(column)
            return column == self.missing_values
    
        def transform(self, X):
            if self.class_col_index is None:
                # If we're using the dependent variable: leave X untouched unless
                # the labels seen in fit line up with X row for row
                if self.y is None or len(self.y) != len(X):
                    return X
                groups = np.asarray(self.y)
            else:
                # If we're using nominal values within a binarised feature (i.e. the
                # classes are unique values within a nominal column - e.g. sex)
                groups = X[:, self.class_col_index]
    
            missing = self._is_missing(X[:, self.replace_col_index])
            stacks = []
            for aclass in np.unique(groups):
                in_class = groups == aclass
                # Boolean indexing returns copies, so X itself is not modified
                with_missing = X[in_class & missing]
                without_missing = X[in_class & ~missing]
    
                # Calculate mean from examples without missing values
                mean = np.mean(without_missing[:, self.replace_col_index])
    
                # Broadcast mean to all missing values
                with_missing[:, self.replace_col_index] = mean
                stacks.append(np.concatenate((with_missing, without_missing)))
    
            # Reassemble our stacks of values. Note that the rows come back grouped
            # by class (imputed rows first), not in their original order.
            return np.concatenate(stacks) if stacks else X
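
The class above computes its means inside transform and returns the rows grouped by class, which makes it awkward to apply to a separate test set or to keep aligned with y. One possible variation, sketched below under some assumptions, learns one statistic per group in fit and fills values in place of a copy, so the row order is preserved. `GroupStatImputer` and its `strategy` parameter are hypothetical names, not part of scikit-learn. The sketch assumes a dense numeric array, NaN as the missing marker, and no missing values in the grouping column.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class GroupStatImputer(BaseEstimator, TransformerMixin):
    """Hypothetical sketch: fill NaNs in one column with a per-group statistic."""

    def __init__(self, replace_col_index, class_col_index, strategy="mean"):
        self.replace_col_index = replace_col_index
        self.class_col_index = class_col_index
        self.strategy = strategy

    def fit(self, X, y=None):
        X = np.asarray(X, dtype=float)
        # NaN-aware reducers, so missing entries are ignored
        stat = {"mean": np.nanmean, "median": np.nanmedian}[self.strategy]
        values = X[:, self.replace_col_index]
        groups = X[:, self.class_col_index]
        # Overall statistic, used for groups that were not seen during fit
        self.fallback_ = stat(values)
        # One statistic per group, computed from the observed values only
        self.statistics_ = {g: stat(values[groups == g]) for g in np.unique(groups)}
        return self

    def transform(self, X):
        X = np.array(X, dtype=float)  # copy, so the caller's data is untouched
        values = X[:, self.replace_col_index]  # view into the copy
        groups = X[:, self.class_col_index]
        for i in np.flatnonzero(np.isnan(values)):
            values[i] = self.statistics_.get(groups[i], self.fallback_)
        return X


# Made-up data: column 0 is the class, column 1 is the fare
X_train = np.array([[1, 80.0], [1, 100.0], [3, 8.0], [3, 10.0], [3, np.nan]])
X_test = np.array([[1, np.nan], [3, np.nan]])

imputer = GroupStatImputer(replace_col_index=1, class_col_index=0).fit(X_train)
X_test_filled = imputer.transform(X_test)
print(X_test_filled)
```

Because the statistics are stored at fit time, the test rows are filled with the training means for their class, and the transformer no longer depends on y being the same length as X.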