python, scikit-learn, svm, grid-search

GridSearchCV - access to predicted values across tests?


Is there a way to get access to the predicted values calculated within a GridSearchCV process?

I'd like to be able to plot the predicted y values against their actual values (from the test/validation set).

Once the grid search is complete, I can make predictions on some other data using

 ypred = grid.predict(xv)

but I'd like to plot the predictions made on the validation folds during the grid search itself. Maybe there's a way of saving those points as a pandas DataFrame?

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import (GridSearchCV, KFold,
                                     cross_val_score, train_test_split)
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

scaler = StandardScaler()
svr_rbf = SVR(kernel='rbf')
pipe = Pipeline(steps=[('scaler', scaler), ('svr_rbf', svr_rbf)])
# parameters, splits, msescorer, xt and yt are defined elsewhere (not shown)
grid = GridSearchCV(pipe, param_grid=parameters, cv=splits, refit=True, verbose=3, scoring=msescorer, n_jobs=4)
grid.fit(xt, yt)

Solution

  • One solution is to write a custom scoring function that saves the predictions it receives (y_pred) into a global list:

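    import numpy as np
    from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in 0.20
    from sklearn.svm import SVR
    from sklearn.metrics import mean_squared_error, make_scorer
    
    X, y = np.random.rand(2, 200)
    clf = SVR(kernel='poly')  # 'degree' is only used by the polynomial kernel
    
    ys = []  # every y_pred passed to the scorer is appended here
    
    def MSE(y_true, y_pred):
        ys.append(y_pred)  # record this validation fold's predictions
        return mean_squared_error(y_true, y_pred)
    
    def scorer():
        return make_scorer(MSE, greater_is_better=False)
    
    n_splits = 3
    # Keep n_jobs=1 and return_train_score=False (both defaults): parallel workers
    # append to their own copy of ys, and train scoring would add extra entries.
    cv = GridSearchCV(clf, {'degree': [1, 2, 3]}, scoring=scorer(), cv=n_splits)
    cv.fit(X.reshape(-1, 1), y)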
    

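    Then group the recorded predictions. GridSearchCV evaluates every fold of one parameter candidate before moving on to the next, so each run of n_splits consecutive entries in ys belongs to one candidate; concatenating each run gives one full array of out-of-fold predictions per candidate. With the default unshuffled KFold, that array is in the same order as y: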

    idxs = range(0, len(ys) + 1, n_splits)
    # e.g. [0, 3, 6, 9] for 3 candidates x 3 splits
    # collect each candidate's n_splits fold predictions into a single list;
    # slice from j[0] so that the candidate's first fold is included
    new = [ys[j[0]:j[1]] for j in zip(idxs, idxs[1:])]
    # concatenate every such list into one array per candidate
    ys = [np.concatenate(i, axis=0) for i in new]
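To get from the grouped predictions to the DataFrame the question asks for, here is a minimal, self-contained sketch. It assumes an integer cv on a regressor (which scikit-learn turns into an unshuffled KFold, so the concatenated test folds line up with y) and n_jobs=1. The names recording_mse, recorded, per_candidate and df are illustrative, not part of any API. As a cross-check it uses scikit-learn's cross_val_predict, a separate technique that refits the best parameter set on the same folds; for a deterministic estimator such as SVR it should reproduce that candidate's recorded predictions.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import make_scorer, mean_squared_error
from sklearn.model_selection import GridSearchCV, cross_val_predict
from sklearn.svm import SVR

rng = np.random.RandomState(0)
X = rng.rand(200, 1)
y = rng.rand(200)
n_splits = 3

recorded = []  # hypothetical global: predictions for every (candidate, fold), in call order

def recording_mse(y_true, y_pred):
    # Side effect: keep this validation fold's predictions, then score normally.
    recorded.append(np.asarray(y_pred))
    return mean_squared_error(y_true, y_pred)

search = GridSearchCV(
    SVR(kernel='poly'),
    {'degree': [1, 2, 3]},
    scoring=make_scorer(recording_mse, greater_is_better=False),
    cv=n_splits,  # unshuffled KFold: test folds cover the samples in order
    n_jobs=1,     # the recorded list is only filled in the main process
)
search.fit(X, y)

# Size of each recorded fold, useful for checking the candidate-major call order.
fold_sizes = [len(p) for p in recorded]

# One array per candidate: its n_splits consecutive fold predictions, concatenated.
per_candidate = [
    np.concatenate(recorded[i:i + n_splits])
    for i in range(0, len(recorded), n_splits)
]

# Long-format DataFrame: one row per (candidate, sample).
frames = []
for params, y_pred in zip(search.cv_results_['params'], per_candidate):
    frames.append(pd.DataFrame({
        'degree': params['degree'],
        'y_true': y,  # aligned only because the folds are unshuffled
        'y_pred': y_pred,
    }))
df = pd.concat(frames, ignore_index=True)

# Cross-check: out-of-fold predictions for the best parameters, computed independently.
best_pred = cross_val_predict(search.best_estimator_, X, y, cv=n_splits)
best_rows = df[df['degree'] == search.best_params_['degree']]
matches = np.allclose(best_rows['y_pred'].to_numpy(), best_pred)
```

From here, filtering df by degree and calling DataFrame.plot.scatter(x='y_true', y='y_pred') (matplotlib required) gives the predicted-versus-actual plot. If only the best parameter set matters, cross_val_predict alone is simpler and does not rely on a global list, so it also works with n_jobs > 1. With a shuffled or custom splitter, the fold test indices from cv.split() would be needed to put the predictions back in y's order.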