scikit-learn scipy logistic-regression

Which solver is throwing this warning?


This warning pops up 10 times. The line number varies between 456 and 305:

C:\Users\foo\Anaconda3\lib\site-packages\scipy\optimize\_linesearch.py:456: LineSearchWarning: The line search algorithm did not converge
  warn('The line search algorithm did not converge', LineSearchWarning)

I'm running a grid search with these parameters:

logistic_regression_grid = {
    "class_weight": ["balanced"], 
    "max_iter":     [100000],
    "solver":       ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
    "random_state": [0]
}

So which solver is throwing the warning, and is it possible to determine that?


Solution

  • I used the iris dataset and set max_iter=10 to deliberately induce a convergence warning. Since you are only interested in the solvers, I looped over them directly instead of using grid search. Using the warnings module and ConvergenceWarning from sklearn.exceptions, I was able to print which solvers do not converge. Here is my code:

    import warnings

    from sklearn.datasets import load_iris
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    # Your logistic regression grid (only the "solver" list is used below)
    logistic_regression_grid = {
        "class_weight": ["balanced"],
        "max_iter":     [100000],
        "solver":       ["lbfgs", "liblinear", "newton-cg", "sag", "saga"],
        "random_state": [0]
    }

    # Load the Iris dataset
    iris = load_iris()
    X, y = iris.data, iris.target

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    # Loop over the solvers and capture warnings
    for solver in logistic_regression_grid["solver"]:
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Fit with max_iter=10 (not the grid's 100000) to force non-convergence
            model = LogisticRegression(class_weight="balanced", max_iter=10, solver=solver, random_state=0)
            model.fit(X_train, y_train)

        # Check whether any recorded warning is a ConvergenceWarning
        if any(issubclass(warning.category, ConvergenceWarning) for warning in w):
            print(f"Solver '{solver}' did not converge.")
    

    Here is the output I get:

    Solver 'lbfgs' did not converge.
    Solver 'newton-cg' did not converge.
    Solver 'sag' did not converge.
    Solver 'saga' did not converge.
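
One caveat: the LineSearchWarning in the question is defined by SciPy, not scikit-learn, and it is not a subclass of ConvergenceWarning, so the check above would not report it. A minimal sketch of a variant that records every warning category per solver is below. The helper name collect_warnings is hypothetical, and I binarized the iris target as an assumption, because how liblinear handles multiclass targets differs across scikit-learn versions.

```python
import warnings

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression


def collect_warnings(solvers, X, y, max_iter=10):
    """Return {solver: sorted list of warning class names raised during fit}.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    result = {}
    for solver in solvers:
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, including repeats and non-sklearn ones
            warnings.simplefilter("always")
            LogisticRegression(
                class_weight="balanced",
                max_iter=max_iter,
                solver=solver,
                random_state=0,
            ).fit(X, y)
        # Keep the class name of each warning, e.g. "ConvergenceWarning"
        result[solver] = sorted({w.category.__name__ for w in caught})
    return result


X, y = load_iris(return_X_y=True)
# Assumption: reduce to a binary problem so every solver accepts the target
y = (y == 2).astype(int)

solvers = ["lbfgs", "liblinear", "newton-cg", "sag", "saga"]
for solver, names in collect_warnings(solvers, X, y).items():
    print(f"{solver}: {names if names else 'no warnings'}")
```

Run on your own data with your real max_iter, any solver whose list contains LineSearchWarning is the one you are looking for. As far as I know, newton-cg is the solver in scikit-learn that calls SciPy's Wolfe line search routines, so it is the likely suspect, but the loop lets you confirm that rather than take it on faith.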