This warning pops up 10 times, with the line number varying between 305 and 456:
C:\Users\foo\Anaconda3\lib\site-packages\scipy\optimize\_linesearch.py:456: LineSearchWarning: The line search algorithm did not converge
warn('The line search algorithm did not converge', LineSearchWarning)
I'm running a grid search with these parameters:
logistic_regression_grid = {
    "class_weight": ["balanced"],
    "max_iter": [100000],
    "solver": ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
    "random_state": [0]
}
So the question is: which solver is throwing the warning? Is it possible to determine that?
I used the iris dataset and set max_iter=10 to purposely induce a convergence warning. Since you are only interested in the solvers, I looped over them without using grid search, and I was able to print which solvers do not converge using the warnings library and the sklearn.exceptions module. Here is my code:
import warnings
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Your logistic regression grid
logistic_regression_grid = {
    "class_weight": ["balanced"],
    "max_iter": [100000],
    "solver": ["lbfgs", "liblinear", "newton-cg", "sag", "saga"],
    "random_state": [0]
}

# Load the Iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Loop over the solvers and capture warnings
for solver in logistic_regression_grid["solver"]:
    with warnings.catch_warnings(record=True) as w:
        # Record every warning, including repeats from the same code location
        warnings.simplefilter("always")
        # Fit logistic regression model with the current solver
        model = LogisticRegression(class_weight="balanced", max_iter=10, solver=solver, random_state=0)
        model.fit(X_train, y_train)
        # Check whether any recorded warning is a ConvergenceWarning
        if any(issubclass(warning.category, ConvergenceWarning) for warning in w):
            print(f"Solver '{solver}' did not converge.")
Here is the output I get:
Solver 'lbfgs' did not converge.
Solver 'newton-cg' did not converge.
Solver 'sag' did not converge.
Solver 'saga' did not converge.
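If you would rather keep using grid search, one possible approach (a sketch, not the only way) is to subclass LogisticRegression so that each fit() records the warnings it raises together with self.solver. Unlike the loop above, which only checks for ConvergenceWarning, this records every warning category, so a scipy LineSearchWarning like the one in the question would also be attributed to a solver. Assumptions: the class name WarningTaggingLogisticRegression and the list CAUGHT_WARNINGS are hypothetical; the breast cancer dataset is used because its binary target is supported by all six solvers, including newton-cholesky (which needs scikit-learn 1.2 or later); and GridSearchCV is left at its default n_jobs=None so every fit runs in the current process.

```python
import warnings

from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Hypothetical collector: each warning raised inside fit() is stored as
# (solver, warning category name, message). A module-level list only works
# when GridSearchCV runs the fits in this process (n_jobs=None).
CAUGHT_WARNINGS = []


class WarningTaggingLogisticRegression(LogisticRegression):
    """Hypothetical LogisticRegression subclass that tags each warning with its solver."""

    def fit(self, X, y, sample_weight=None):
        with warnings.catch_warnings(record=True) as caught:
            # Record repeats instead of showing each warning only once per location
            warnings.simplefilter("always")
            super().fit(X, y, sample_weight=sample_weight)
        # The recorded warnings are not re-emitted; store them with the solver name
        for w in caught:
            CAUGHT_WARNINGS.append((self.solver, w.category.__name__, str(w.message)))
        return self


# The question's grid, except max_iter is lowered (assumption) to force non-convergence
grid = {
    "class_weight": ["balanced"],
    "max_iter": [10],
    "solver": ["lbfgs", "liblinear", "newton-cg", "newton-cholesky", "sag", "saga"],
    "random_state": [0],
}

# Binary target, so every solver in the grid accepts it
X, y = load_breast_cancer(return_X_y=True)

# Default n_jobs=None: all fits (including the final refit) run in this process
search = GridSearchCV(WarningTaggingLogisticRegression(), grid, cv=3)
search.fit(X, y)

for solver, category, message in CAUGHT_WARNINGS:
    print(f"{solver}: {category}: {message}")
```

Because the warnings are recorded inside fit(), they no longer appear on the console; the final loop prints them with the solver that produced them. With n_jobs greater than 1 the fits run in worker processes, each appending to its own copy of the list, so the main process would not see the entries.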