python, matplotlib, scikit-learn, logistic-regression

Feature importance using GridSearchCV for logistic regression


I've trained a logistic regression model like this:

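from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

reg = LogisticRegression(random_state=40)
# Newer scikit-learn (1.2+) expects None rather than the string 'none'
cvreg = GridSearchCV(reg, param_grid={'C': [0.05, 0.1, 0.5],
                                      'penalty': ['none', 'l1', 'l2'],
                                      'solver': ['saga']},
                     cv=5)
cvreg.fit(X_train, y_train)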

To show the feature importances I tried the code below, but the plot doesn't show the feature names:

from matplotlib import pyplot

importance = cvreg.best_estimator_.coef_[0]
pyplot.bar(range(len(importance)), importance)
pyplot.show()


Obviously, the plot isn't very informative. How do I add the names of the coefficients to the x-axis?

The coefficients are:

cvreg.best_estimator_.coef_
array([[1.10303023e+00, 7.48816905e-01, 4.27705027e-04, 6.01404570e-01]])

Solution

  • The coefficients are in the same order as the columns of X_train, so pass the X_train column names as the x values instead of range(len(importance)).

    Assuming X_train is a pandas dataframe:

    import matplotlib.pyplot as plt
    
    features = X_train.columns
    importance = cvreg.best_estimator_.coef_[0]
    
    plt.bar(features, importance)
    plt.show()
    

    Note that if X_train is a plain NumPy array without column names, you will have to define the features list yourself from your own data dictionary, in the same order as the columns of the array.
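For the NumPy case, a minimal sketch could look like the following. It reuses the coefficient values posted in the question so that it runs on its own. The names in `features` are hypothetical placeholders, since the question doesn't say what the columns are, so swap in your real names in column order.

```python
import matplotlib.pyplot as plt
import numpy as np

# Coefficients copied from the question; in practice use
# cvreg.best_estimator_.coef_ directly.
coef = np.array([[1.10303023e+00, 7.48816905e-01, 4.27705027e-04, 6.01404570e-01]])
importance = coef[0]

# Hypothetical names: replace with your own, in the same order as
# the columns of X_train.
features = ["feature_a", "feature_b", "feature_c", "feature_d"]

# Guard against a names list that doesn't match the number of coefficients.
if len(features) != len(importance):
    raise ValueError("need exactly one name per coefficient")

fig, ax = plt.subplots()
bars = ax.bar(features, importance)
ax.set_ylabel("Coefficient")
ax.tick_params(axis="x", rotation=45)  # keep longer names readable
fig.tight_layout()
plt.show()
```

The length check is there because a mismatched names list is the easiest mistake to make when the names are maintained by hand.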
