pandas, python-3.6, statsmodels

How to run a multicollinearity test on a pandas dataframe?


I am comparatively new to Python, statistics, and data science libraries. I need to run a multicollinearity test on a dataset with n columns and drop every column/variable whose VIF is greater than 5.

I found the following code:

    from statsmodels.stats.outliers_influence import variance_inflation_factor

    def calculate_vif_(X, thresh=5.0):

        variables = range(X.shape[1])
        tmp = range(X[variables].shape[1])
        print(tmp)
        dropped=True
        while dropped:
            dropped=False
            vif = [variance_inflation_factor(X[variables].values, ix) for ix in range(X[variables].shape[1])]

            maxloc = vif.index(max(vif))
            if max(vif) > thresh:
                print('dropping \'' + X[variables].columns[maxloc] + '\' at index: ' + str(maxloc))
                del variables[maxloc]
                dropped=True

        print('Remaining variables:')
        print(X.columns[variables])
        return X[variables]

However, I do not clearly understand how to call it: should I pass the whole dataset as the X argument? If so, it is not working.

Please help!


Solution
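
One likely reason the snippet in the question fails on Python 3 is that `range(...)` returns an immutable range object, so `del variables[maxloc]` raises a `TypeError`. In addition, `X[variables]` looks columns up by label rather than by position, which only works when the column labels happen to be those integers. The small sketch below reproduces only the first problem, using plain Python and no DataFrame:

```python
variables = range(3)  # Python 3: an immutable range object, not a list
try:
    del variables[1]
    error = None
except TypeError as exc:
    # Deleting an item from a range object is not supported.
    error = str(exc)
print(error)

variables = list(range(3))  # wrapping it in list() makes deletion possible
del variables[1]
print(variables)  # [0, 2]
```

Wrapping the range in `list(...)` and selecting columns with `.iloc` avoids both problems.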

  • I tweaked the code and got the desired result with the following, which adds a little exception handling:

    from statsmodels.stats.outliers_influence import variance_inflation_factor


    def multicollinearity_check(X, thresh=5.0):
        # VIF needs numeric input, so count the integer/float columns first.
        numeric_cols = X.select_dtypes(
            include=['int', 'int16', 'int32', 'int64', 'float', 'float16', 'float32', 'float64']).shape[1]
        total_cols = X.shape[1]
        try:
            if numeric_cols != total_cols:
                raise Exception('All the columns should be integer or float, for multicollinearity test.')
            else:
                # Column positions (not labels) of the features still under consideration.
                variables = list(range(X.shape[1]))
                dropped = True
                print('\n\nThe VIF calculator will now iterate through the features and calculate their respective values.\n'
                      'It shall continue dropping the highest VIF features until no feature has a VIF '
                      'above the threshold of ' + str(thresh) + '.\n\n')
                while dropped:
                    dropped = False
                    vif = [variance_inflation_factor(X.iloc[:, variables].values, ix) for ix in variables]
                    print('\n\nvif is: ', vif)
                    maxloc = vif.index(max(vif))
                    if max(vif) > thresh:
                        print('dropping \'' + X.iloc[:, variables].columns[maxloc] + '\' at index: ' + str(maxloc))
                        # Drop the column from X itself (this modifies the caller's DataFrame),
                        # then rebuild the positions so they match the remaining columns.
                        X.drop(columns=X.columns[variables[maxloc]], inplace=True)
                        variables = list(range(X.shape[1]))
                        dropped = True

                print('\n\nRemaining variables:\n')
                print(X.columns[variables])
                return X
        except Exception as e:
            print('Error caught: ', e)