python scikit-learn gridsearchcv

GridSearchCV runs smoothly when scoring='accuracy', but not when scoring=accuracy_score


When I run the following piece of code in a Jupyter notebook inside Visual Studio Code, it runs smoothly.

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True, as_frame=True)

gs = GridSearchCV(estimator=KNeighborsClassifier(),
                  param_grid=[{'n_neighbors': [3]}],
                  scoring='accuracy')
#                  scoring=accuracy_score)

gs.fit(X, y)

However, if I un-comment the commented line and comment the line above it, and re-run the notebook, I get the following error. Why?

c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py", line 971, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 191, in wrapper
    params = func_sig.bind(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3267, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3191, in _bind
    raise TypeError(
TypeError: too many positional arguments

  warnings.warn(
[The same warning and traceback repeat four more times, once for each remaining cross-validation fold.]

c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_search.py:1052: UserWarning: One or more of the test scores are non-finite: [nan]
  warnings.warn(

Solution

  • Interesting question, had to do some research to find out the reason.

    According to the documentation for GridSearchCV, the scoring parameter accepts either a string naming a predefined scorer or a callable, so passing a callable, as you tried, is supported. However, the callable must have a specific signature, which is easy to miss in the documentation.

    The problem is that you passed the accuracy_score function itself. A callable given as scoring must have the signature that GridSearchCV expects: it takes the fitted model, the test features X_test, and the true labels y_test (plus, optionally, **kwargs for additional parameters), and it returns a score.

    accuracy_score does not fit this pattern: it is a metric, not a scorer, and accepts only two positional arguments, y_true and y_pred. GridSearchCV calls scorer(estimator, X_test, y_test) with three positional arguments, which is why the traceback ends in "TypeError: too many positional arguments".
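To make the expected signature concrete, here is a minimal sketch of a hand-written scorer. The function name accuracy_scorer is hypothetical and chosen only for illustration, and the data is loaded without as_frame=True to keep the example short; the rest mirrors the code in the question.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier


def accuracy_scorer(estimator, X, y):
    """Hypothetical scorer with the (estimator, X, y) signature."""
    # Predict first, then pass y_true and y_pred to the metric.
    return accuracy_score(y, estimator.predict(X))


X, y = load_iris(return_X_y=True)

gs = GridSearchCV(estimator=KNeighborsClassifier(),
                  param_grid=[{'n_neighbors': [3]}],
                  scoring=accuracy_scorer)
gs.fit(X, y)
print(gs.best_score_)
```

Because the function accepts three positional arguments and does the prediction itself, the call made during cross-validation binds cleanly and no score is set to nan.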

    Solution: use make_scorer. It wraps a metric function such as accuracy_score in a scorer with the expected signature: the scorer calls the estimator's predict method on X_test itself and then passes y_true and y_pred to the metric.

    Use it like this:

    from sklearn.metrics import accuracy_score, make_scorer

    # Wrap the metric so it is called with the scorer signature (estimator, X, y).
    accuracy_scorer = make_scorer(accuracy_score)
    ...
    gs = GridSearchCV(estimator=KNeighborsClassifier(),
                      param_grid=[{'n_neighbors': [3]}],
                      scoring=accuracy_scorer)
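As a quick sanity check, the following self-contained sketch fits the same grid search twice, once with scoring='accuracy' and once with make_scorer(accuracy_score), and compares the results. The variable names gs_string, gs_scorer, and same are illustrative, and the data is loaded without as_frame=True for brevity. With the default, non-shuffled cross-validation splits, the two scores should agree.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, make_scorer
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
param_grid = [{'n_neighbors': [3]}]

# Built-in scorer selected by name.
gs_string = GridSearchCV(KNeighborsClassifier(), param_grid,
                         scoring='accuracy').fit(X, y)

# The same metric wrapped explicitly with make_scorer.
gs_scorer = GridSearchCV(KNeighborsClassifier(), param_grid,
                         scoring=make_scorer(accuracy_score)).fit(X, y)

# Compare the mean cross-validated accuracy of the two searches.
same = np.isclose(gs_string.best_score_, gs_scorer.best_score_)
print(gs_string.best_score_, gs_scorer.best_score_, same)
```

The explicit make_scorer form is mainly useful when the metric needs extra arguments or has no predefined string name; for plain accuracy, the string is the shorter option.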