python matplotlib grid-search

Cannot reshape array - plot GridSearchCV


The code is simple and there are several related questions, but my knowledge of Python is next to nothing, so I have no idea how this works. I'm trying to plot my GridSearchCV results. Reading the docs didn't help: https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

clf = GridSearchCV(estimator=svm.SVC(),
               param_grid={'C': [1, 10], 'gamma': [0.001, 0.0001], 'kernel': ('linear', 'rbf')}, cv=10, n_jobs=None)
clf.fit(X_train, Y_train)

scores = [x[1] for x in clf.cv_results_]    
print(np.array(scores).shape)  # outputs: (33L,)
scores = np.array(scores).reshape(len([1, 10]), len([0.001, 0.0001]))

for ind, i in enumerate([1, 10]):
    plt.plot([0.001, 0.0001], scores[ind])

plt.legend()
plt.xlabel('Gamma')
plt.ylabel('Mean score')
plt.show()

Output error:

scores = np.array(scores).reshape(len([1, 10]), len([0.001, 0.0001]))
ValueError: cannot reshape array of size 33 into shape (2,2)

Why does this happen and how do I fix it?


Solution

  • Figured it out. First, to plot the GridSearchCV results, I need to read the mean_test_score entry of the clf.cv_results_ dictionary.

    That gives an array of size 8, which still triggers the error:

    scores = np.array(scores).reshape(len([1, 10]), len([0.001, 0.0001]))
    ValueError: cannot reshape array of size 8 into shape (2,2)
    

    After some fiddling with the code, this is how it should look to work:

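    # Target shape is (number of C values, number of gamma values);
    # the product must equal the number of scores (2 * 4 = 8).
    scores = clf.cv_results_['mean_test_score'].reshape(len([1, 10]), len([0.001, 0.0006, 0.0003, 0.0001]))
    
    for ind, i in enumerate([1, 10]):
        # Label each line so that plt.legend() has something to show.
        plt.plot([0.001, 0.0006, 0.0003, 0.0001], scores[ind], label='C: ' + str(i))
    
    plt.legend()
    plt.xlabel('Gamma')
    plt.ylabel('Mean score')
    plt.show()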
    

    This arranges the 8 scores in a 2x4 matrix, which holds exactly 8 values. That was my earlier mistake: the target shape has to hold exactly as many elements as the array, and 2*2 != 33 (a bit of a dumb problem, but I lacked basic knowledge of how this works).
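
A note on where the sizes come from, as a rough sketch rather than the answerer's exact code. Iterating over a dictionary yields its keys, so the original list comprehension collected one character per key instead of any scores. The grid in the question has three parameters, which gives 2 * 2 * 2 = 8 candidates. The `cv_results` dictionary below is a hypothetical stand-in for `clf.cv_results_` with made-up scores, and `C_values`, `gamma_values` and `kernels` are names introduced only for this illustration. It places each score by looking up its parameters, so it does not rely on the order of the candidates.

```python
import itertools
import numpy as np

# Parameter values copied from the question's param_grid.
C_values = [1, 10]
gamma_values = [0.001, 0.0001]
kernels = ['linear', 'rbf']

# Hypothetical stand-in for clf.cv_results_ (assumption: only the two
# entries used here, with made-up scores instead of real CV results).
params = [{'C': c, 'gamma': g, 'kernel': k}
          for c, g, k in itertools.product(C_values, gamma_values, kernels)]
cv_results = {
    'params': params,
    'mean_test_score': np.linspace(0.5, 0.85, len(params)),
}

# Iterating over a dict yields its keys, so x[1] is only the second
# character of each key name, not a score.
second_chars = [x[1] for x in cv_results]

# One score per parameter combination: 2 * 2 * 2 = 8.
n_candidates = len(cv_results['mean_test_score'])

# Fill a (C, gamma, kernel) array by looking up each candidate's parameters.
scores = np.full((len(C_values), len(gamma_values), len(kernels)), np.nan)
for p, s in zip(cv_results['params'], cv_results['mean_test_score']):
    i = C_values.index(p['C'])
    j = gamma_values.index(p['gamma'])
    k = kernels.index(p['kernel'])
    scores[i, j, k] = s

# Slice out one kernel to get a 2-D table: rows are C, columns are gamma.
rbf_scores = scores[:, :, kernels.index('rbf')]
print(n_candidates, scores.shape, rbf_scores.shape)
```

With a fitted search, `clf.cv_results_` would take the place of `cv_results`. Each row of `rbf_scores` could then be plotted against `gamma_values` with one `plt.plot` call per value of C, as in the loop above.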