
Why is my simple MATLAB gradient descent for linear regression not working?


I am starting to learn linear regression and wanted to implement gradient descent myself. I wrote the code below.

%% Linear regression

close all;

dataset =load('accidents');
data = dataset.hwydata;
x = data(:,14);
y  =data(:,4);
%% Gradient descent
% We want to minimize a cost function and GD achieves that iteratively.
% J(w,b) = mean((y - y_est).^2)

w = 0;
b = 0;
alpha =.000001; % I tried various alphas like 0.01, .1 etc. (Not working)
for i =1: 100
    y_est = w*x + b;
    J = mean((y-y_est).^2)
    temp_w = w + alpha*(mean(x.*(y-w*x-b)));
    temp_b = b + alpha*(mean(y-w*x-b));
    w =temp_w
    b =temp_b
end

For some reason, it is not working. It seems like my algorithm is not converging.

I expected the algorithm to converge nicely because the mean squared error cost function is convex.


Solution

  • The short answer: You need some sort of line search along the direction of the gradient.

    As I mentioned in my inquiry, large values of x and y can cause the algorithm to diverge, and that is what is happening in your program: there are x values up to approx. 3e7, and the initial value of the partial derivative w.r.t. w is about 1e10.
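
    To make the effect of large x values concrete: for this quadratic cost, the update w <- w + alpha*mean(x.*r), b <- b + alpha*mean(r) (with residual r = y - w*x - b) only converges with a fixed alpha when alpha < 2/lambdaMax, where lambdaMax is the largest eigenvalue of H = [mean(x.^2), mean(x); mean(x), 1]. The sketch below evaluates that bound on made-up x values that merely mimic the 1e7 scale; they are an assumption for illustration, not the accidents data.

    ```matlab
    % Synthetic stand-in for x (assumption: chosen only to mimic the scale)
    x = [1e5; 5e5; 2e6; 1e7; 3e7];

    % Matrix that governs the fixed-step iteration for [w; b]
    H = [mean(x.^2), mean(x); mean(x), 1];

    % A fixed step is stable only below 2/lambdaMax
    lambdaMax = max(eig(H));
    alphaMax  = 2/lambdaMax;
    fprintf('largest stable fixed alpha is about %g\n', alphaMax);
    ```

    With x values of this size the bound is many orders of magnitude below alphas such as 1e-6 or 0.01, which is consistent with the divergence seen in the question.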

    For a very crude and simple line search: compare the current merit function (Jcurrent in the code below) with the one calculated with temp_w, temp_b, and the current alpha (Jnew in the code below).

    If Jnew >= Jcurrent, reduce alpha by some factor and recalculate temp_w and temp_b with the new alpha and the current gradient. Recalculate Jnew with the updated temp_w and temp_b. Repeat until Jnew < Jcurrent.

    After this line search you have at least two options: 1) Reset alpha to its starting value (alphaOrig in the code) or 2) Keep the current alpha.

    Notice that this line search is very far from optimal. It searches only for a reduction in the merit function and accepts it regardless of how small the reduction is. This causes slow convergence. Let me know if you want me to suggest better line-search approaches.
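
    One common refinement is backtracking with the Armijo sufficient-decrease condition, which accepts a step only if J drops by at least c*alpha*||g||^2. This is a different technique from the crude search described above, shown here only as a minimal sketch: the data are synthetic, and the constants c, alpha0, and the halving factor are hypothetical choices, not tuned values.

    ```matlab
    % Synthetic data (assumption: made up for illustration)
    x = (-5:5)';
    y = 3*x + 2;

    J = @(w,b) mean((y - w*x - b).^2);   % mean squared error
    w = 0;  b = 0;
    c = 1e-4;       % sufficient-decrease constant (hypothetical choice)
    alpha0 = 1;     % trial step restarted at every iteration

    for k = 1:200
        r = y - w*x - b;
        g = -2*[mean(x.*r); mean(r)];    % true gradient of J w.r.t. [w; b]
        Jcur  = J(w,b);
        alpha = alpha0;
        % Armijo backtracking: halve alpha until the decrease is sufficient
        while J(w - alpha*g(1), b - alpha*g(2)) > Jcur - c*alpha*(g'*g)
            alpha = 0.5*alpha;
        end
        w = w - alpha*g(1);
        b = b - alpha*g(2);
    end
    fprintf('w = %g, b = %g, J = %g\n', w, b, J(w,b));
    ```

    Restarting from alpha0 at every iteration corresponds to option 1) above; keeping the last accepted alpha would correspond to option 2).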

    JAC

    %% Linear regression
    
    clear; 
    % Delete all figures
    figureList = findobj('type', 'figure');
    if ~isempty(figureList)
        delete(figureList);
    end
        alphaOrig = 1e-4;
        dataset =load('accidents');
        data = dataset.hwydata;
        x = data(:,14);
        y  =data(:,4);
    % ..Check the ranges of x and y    
        fprintf(1, 'max(x) = %g\tmax(y) = %g\n', max(abs(x)), max(abs(y)) );
    %% Gradient descent
    % We want to minimize a cost function and GD achieves that iteratively.
    % J(w,b) = mean((y - y_est).^2)
    
        w = 0;
        b = 0;
        alpha = alphaOrig;
    % ..Initial value of merit function
        Jcurrent = mean((y - w*x - b).^2);
        for i =1: 100
        % ..Components of the descent direction: minus one half of the
        %   gradient of J (the factor 2 is absorbed into alpha)
            dJdw = mean(x.*(y-w*x-b));
            dJdb = mean(y-w*x-b);
            fprintf(1, '%d: J(%g,%g) = %.8g\t', i, w,b,Jcurrent);
            fprintf(1, 'dJ/dw = %g; dJ/db = %g\n', dJdw, dJdb);
        % ..Crude line search
            while true % Loop not in the original at stackoverflow
                temp_w = w + alpha*dJdw;
                temp_b = b + alpha*dJdb;
                Jnew = mean((y - temp_w*x - temp_b).^2);
                fprintf(1, '\talpha = %g;\tJ(%g,%g) = %.8g; \n', ...
                    alpha, temp_w, temp_b, Jnew);
                if Jnew < Jcurrent
                    break;
                end
                alpha = 0.1*alpha; % <== reduce alpha
                if alpha == 0
                    error("Sorry. It's not going well");
                end
            end
            Jcurrent = Jnew;
    %         alpha = alphaOrig; % This resets alpha
            w =temp_w;
            b =temp_b;
        end
        fmt = "%10s: w = %g; b= %g\n";
        fprintf(1, fmt, "Estimate", w, b);
        fig = figure(1);clf
        plot(x,y, 'linestyle', 'none','marker', 'o');
        hold on
    % ..Show the least-squares fit
        t = [0.9*min(x),1.2*max(x)];
        plot(t,w*t+b, 'linestyle', '--', 'color', 'black');
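
    As a sanity check, a gradient-descent estimate of w and b can be compared with the closed-form least-squares solution, which MATLAB provides through the backslash operator or polyfit. A minimal sketch on made-up data (the x and y values are assumptions for illustration, not the accidents set):

    ```matlab
    % Synthetic data (assumption: made up for illustration)
    x = [1; 2; 3; 4; 5];
    y = [2.1; 3.9; 6.2; 7.8; 10.1];

    % Closed-form least squares: theta(1) is the slope, theta(2) the intercept
    theta = [x, ones(size(x))] \ y;

    % polyfit returns the same coefficients as [slope, intercept]
    p = polyfit(x, y, 1);

    fprintf('backslash: w = %g; b = %g\n', theta(1), theta(2));
    fprintf('polyfit:   w = %g; b = %g\n', p(1), p(2));
    ```

    If the gradient-descent values are still far from these after the loop ends, the iteration has not converged yet, which can happen with the crude line search on badly scaled data.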