
Using .set_params() function for LinearRegression


I recently started working on machine learning with linear regression. I used a LinearRegression model (lr) to predict some values. However, my predictions were poor, and I was asked to tune the hyperparameters to get better results.

I used the following commands to list the hyperparameters:

    lr.get_params().keys()
    lr.get_params()

and obtained the following:

    {'copy_X': True,
     'fit_intercept': True,
     'n_jobs': None,
     'normalize': False,
     'positive': False}

and

    dict_keys(['copy_X', 'fit_intercept', 'n_jobs', 'normalize', 'positive'])
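Since `get_params()` returns an ordinary dict, its keys are exactly the names that `set_params()` will accept. A minimal sketch, assuming a recent scikit-learn install; the printed names depend on the installed version (for example, `normalize` was removed from `LinearRegression` in scikit-learn 1.2), so the code checks for a name rather than assuming it exists:

```python
from sklearn.linear_model import LinearRegression

lr = LinearRegression()

# get_params() returns a plain dict of {parameter name: current value}.
params = lr.get_params()
print(type(params))  # <class 'dict'>

# These names are the only valid keyword arguments for set_params().
print(sorted(params))

# Available names vary by scikit-learn version, so check before relying on one.
print("normalize available:", "normalize" in params)
```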

Now, this is where the problems started. I tried to find the correct syntax for the .set_params() method, but I could not follow any of the answers I found.

I tried passing a positional argument, but calls such as lr.set_params('normalize'==True) returned

    TypeError: set_params() takes 1 positional argument but 2 were given

and lr.set_params(some_params = {'normalize'}) returned

    ValueError: Invalid parameter some_params for estimator LinearRegression(). Check the list of available parameters with `estimator.get_params().keys()`.
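Both errors can be reproduced in a few lines. This is a sketch assuming a scikit-learn release where `set_params` is defined as `set_params(**params)`; the exact wording of the error messages differs between versions. `'normalize' == True` is a comparison that evaluates to `False`, so the first call passes a single positional `False`, which `set_params` does not accept. In the second call, every keyword name is treated as a parameter name, and `some_params` is not a parameter of `LinearRegression` (note too that `{'normalize'}` is a set, not a dict):

```python
from sklearn.linear_model import LinearRegression

lr = LinearRegression()

# A string compared with a bool is simply False...
print('normalize' == True)  # False

# ...so this passes False positionally; set_params only takes keyword arguments.
try:
    lr.set_params('normalize' == True)
except TypeError as exc:
    print("TypeError:", exc)

# The keyword name itself must be a valid parameter name; 'some_params' is not.
try:
    lr.set_params(some_params={'normalize'})
except ValueError as exc:
    print("ValueError:", exc)
```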

Can someone provide a simple explanation of how this function works?


Solution

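  • The correct syntax is set_params(**params), where params is a dictionary mapping parameter names to their new values. The ** unpacks that dictionary into keyword arguments, so reg.set_params(**{'fit_intercept': False}) is the same call as reg.set_params(fit_intercept=False). See the scikit-learn documentation for details.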

    from sklearn.linear_model import LinearRegression
    
    reg = LinearRegression()
    
    reg.get_params()
    # {'copy_X': True,
    #  'fit_intercept': True,
    #  'n_jobs': None,
    #  'normalize': False,
    #  'positive': False}
    
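    # ** unpacks the dict into keyword arguments, so this is the same as
    # reg.set_params(copy_X=False, fit_intercept=False, n_jobs=-1, ...).
    # 'normalize' only exists in older scikit-learn: it was deprecated in 1.0
    # and removed in 1.2, where passing it raises a ValueError.
    reg.set_params(**{
        'copy_X': False,
        'fit_intercept': False,
        'n_jobs': -1,
        'normalize': True,
        'positive': True
    })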
    
    reg.get_params()
    # {'copy_X': False,
    #  'fit_intercept': False,
    #  'n_jobs': -1,
    #  'normalize': True,
    #  'positive': True}
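The same method can be used with plain keyword arguments, and because `set_params()` returns the estimator itself, calls can be chained. The loop below is a hypothetical way to act on the original goal of comparing hyperparameter settings: it uses synthetic data from `make_regression` (an assumption standing in for the asker's dataset) and ranks a few candidate settings by mean 5-fold cross-validated R². `LinearRegression` has few settings to vary, so gains from this alone may be modest:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score

reg = LinearRegression()

# Keyword arguments work directly; no dict is needed.
reg.set_params(fit_intercept=False, positive=True)
print(reg.get_params()["fit_intercept"])  # False

# set_params() returns the estimator, so it can be chained after construction.
chained = LinearRegression().set_params(fit_intercept=False)

# Hypothetical tuning loop on synthetic data (not the asker's dataset).
X, y = make_regression(n_samples=200, n_features=5, noise=10.0, random_state=0)
candidates = [
    {"fit_intercept": True, "positive": False},
    {"fit_intercept": False, "positive": False},
    {"fit_intercept": True, "positive": True},
]
scores = []
for params in candidates:
    model = LinearRegression().set_params(**params)
    # Default scoring for a regressor is its score() method, i.e. R^2.
    scores.append(cross_val_score(model, X, y, cv=5).mean())
    print(params, round(scores[-1], 3))

best = candidates[scores.index(max(scores))]
print("best setting:", best)
```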