python, scikit-learn, grid-search

Failure to reproduce GridSearch from sklearn in python


I am trying to do something similar to GridSearch in sklearn: I want a list of three models in which all parameters are fixed except C, which takes the values 1, 10, and 100 respectively. I have the following two functions.

import itertools

def params_GridSearch(dic_params):
    # Build one parameter dict per combination of the candidate values.
    keys, values = dic_params.keys(), dic_params.values()
    lst_params = []
    for vs in itertools.product(*values):
        lst_params.append(dict(zip(keys, vs)))
    return lst_params

def models_GridSearch(model, dic_params):
    models = [ model.set_params(**params) for params in params_GridSearch(dic_params) ]
    return models

I then build a model and specify a dictionary of parameters.

from sklearn.svm import SVC
model = SVC()
dic = {'C': [1,10,100]}

And generate the models using the functions I just defined.

models = models_GridSearch(model, dic)

However, the outcome is the same model (using the last parameter, i.e. 100) being repeated 3 times. It seems there is some aliasing going on.
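
The aliasing suspicion can be checked with an identity test. The sketch below does not use sklearn itself: `Estimator` is a hypothetical stand-in class whose `set_params` modifies the object in place and returns `self`, which is how sklearn estimators behave.

```python
class Estimator:
    """Hypothetical stand-in for an sklearn estimator (not a real sklearn class)."""

    def __init__(self, C=1.0):
        self.C = C

    def set_params(self, **params):
        # Mutates the object in place and returns self, like sklearn's set_params.
        for name, value in params.items():
            setattr(self, name, value)
        return self


model = Estimator()
models = [model.set_params(C=c) for c in (1, 10, 100)]

same_object = all(m is model for m in models)  # every entry is the original object
c_values = [m.C for m in models]               # so all show the last value assigned
print(same_object, c_values)
```

If the list really holds three references to one object, this prints `True [100, 100, 100]`, matching the behaviour described above.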


Solution

  • model refers to the same object in every iteration of the list comprehension in models_GridSearch, and set_params modifies that object in place and returns it, so you are just assigning a C value three times to one object and storing three references to it. There are a few ways to fix this: make a copy of the object with the copy module, or pass the class into models_GridSearch instead of an instance and instantiate a new object on each iteration. You could also refactor your code in other ways; it depends on your goals.

    Copy method:

    import copy
    
    def models_GridSearch(model, dic_params):
        models = [ copy.deepcopy(model).set_params(**params) for params in params_GridSearch(dic_params) ]
        return models
    

    Pass in class:

    def models_GridSearch(Model, dic_params):
        models = [ Model().set_params(**params) for params in params_GridSearch(dic_params) ]
        return models
    
    
    from sklearn.svm import SVC
    Model = SVC
    dic = {'C': [1,10,100]}
    
    models = models_GridSearch(Model, dic)
    print(models)
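
A third option, not covered above, is to lean on scikit-learn's own helpers: `sklearn.base.clone` builds a new unfitted estimator with the same parameters, and `sklearn.model_selection.ParameterGrid` iterates over the parameter combinations. This is only a minimal sketch, assuming scikit-learn is installed; the function name `models_grid` is made up for illustration.

```python
from sklearn.base import clone
from sklearn.model_selection import ParameterGrid
from sklearn.svm import SVC


def models_grid(model, dic_params):
    # models_grid is a hypothetical name, not part of scikit-learn.
    # clone() returns a new, unfitted estimator with the same parameters,
    # so each set_params call acts on a separate object.
    return [clone(model).set_params(**params) for params in ParameterGrid(dic_params)]


models = models_grid(SVC(), {'C': [1, 10, 100]})
c_values = [m.C for m in models]
print(c_values)
```

Unlike `copy.deepcopy`, `clone` does not carry over any fitted state, which may or may not be what you want if the model passed in has already been trained.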