Tags: python, numpy, machine-learning, scikit-learn, cross-validation

Choosing the best model with k-fold cross-validation


I want to take the Iris data and choose the best logistic regression model using the GridSearchCV function.

My work so far

import numpy as np
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target

# Logistic regression 
reg_log = LogisticRegression()

# Penalties
pen = ['l1', 'l2','none']

# Regularization strength: 100 values from 1e-10 to 1e10
C = np.logspace(-10, 10, 100)

# Possibilities for those parameters
parameters= dict(C=C, penalty=pen)

# choosing best model based on 5-fold cross validation
Model = GridSearchCV(reg_log, parameters, cv=5)

# Fitting best model
Best_model = Model.fit(X, y)

And I get a lot of errors. Do you know what I'm doing wrong?


Solution

  • Since you are trying different regularization penalties, note what the documentation says:

    The ‘newton-cg’, ‘sag’, and ‘lbfgs’ solvers support only L2 regularization with primal formulation, or no regularization. The ‘liblinear’ solver supports both L1 and L2 regularization, with a dual formulation only for the L2 penalty. The Elastic-Net regularization is only supported by the ‘saga’ solver.
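
    To see why the original code fails: the default solver is 'lbfgs', which rejects the 'l1' penalty outright. A minimal sketch (the exact error message depends on your scikit-learn version):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
X = X[:, :2]  # same two features as in the question

# The default solver 'lbfgs' does not support L1 regularization,
# so the fit fails before any cross-validation is done.
try:
    LogisticRegression(penalty='l1').fit(X, y)
except ValueError as err:
    print(err)
```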

    I am not sure it makes sense to include penalty = 'none' in a grid search over C values, since C has no effect without a penalty. So if you use the saga solver (which supports both l1 and l2) and increase the iteration limit:

    import pandas as pd

    reg_log = LogisticRegression(solver="saga", max_iter=1000)
    
    pen = ['l1', 'l2']
    C = [0.1, 0.001]
    
    parameters = dict(C=C, penalty=pen)
    
    Model = GridSearchCV(reg_log, parameters, cv=5)
    
    Best_model = Model.fit(X, y)
    
    res = pd.DataFrame(Best_model.cv_results_)
    res[['param_C', 'param_penalty', 'mean_test_score']]
    
       param_C param_penalty  mean_test_score
    0      0.1            l1         0.753333
    1      0.1            l2         0.833333
    2    0.001            l1         0.333333
    3    0.001            l2         0.700000
    

    This works fine. If you still get errors with your regularization values, inspect them and make sure they are not extreme (your original np.logspace(-10, 10, 100) produces values as small as 1e-10 and as large as 1e10, which can cause convergence problems).
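
    Once the grid search has been fitted, you don't have to scan the results table by hand. A sketch (re-using the same grid as above) of reading off the winning combination directly:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)
X = X[:, :2]

reg_log = LogisticRegression(solver="saga", max_iter=1000)
parameters = dict(C=[0.1, 0.001], penalty=['l1', 'l2'])

Best_model = GridSearchCV(reg_log, parameters, cv=5).fit(X, y)

# GridSearchCV refits the best parameter combination on the full data,
# so Best_model can be used directly for prediction afterwards.
print(Best_model.best_params_)  # the best C / penalty pair
print(Best_model.best_score_)   # its mean cross-validated accuracy
```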