
Gradient descent for linear regression is not working

I tried making a program for linear regression using gradient descent for some sample data. The theta values that I get do not give the best fit for the data. I have already normalized the data.

public class OneVariableRegression {

    public static void main(String[] args) {

        double x1[] = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double y[] = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double theta0 = 0.5;
        double theta1 = 0.5;
        double temp0;
        double temp1;
        double alpha = 1.5;
        double m = x1.length;
        double derivative0 = 0;
        double derivative1 = 0;
        do {
            for (int i = 0; i < x1.length; i++) {
                derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
                derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];
            }
            temp0 = theta0 - (alpha * derivative0);
            temp1 = theta1 - (alpha * derivative1);
            theta0 = temp0;
            theta1 = temp1;
            //System.out.println("Derivative0 = " + derivative0);
            //System.out.println("Derivative1 = " + derivative1);
        } while (derivative0 > 0.0001 || derivative1 > 0.0001);
        System.out.println("theta 0 = " + theta0);
        System.out.println("theta 1 = " + theta1);
    }
}
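The x1 values above match z-score normalization of the integers 1 to 20 using the sample (n - 1) standard deviation; for example, (1 - 10.5) / sqrt(35) is about -1.605793. Assuming that was the raw input, here is a minimal sketch of the normalization step. The class name NormalizeFeature and the method zScore are illustrative, not part of the original code.

```java
import java.util.Arrays;

public class NormalizeFeature {

    // Z-score normalization: subtract the mean, then divide by the
    // sample (n - 1) standard deviation.
    static double[] zScore(double[] raw) {
        int n = raw.length;
        double mean = 0;
        for (double v : raw) mean += v;
        mean /= n;

        double sumSquares = 0;
        for (double v : raw) sumSquares += (v - mean) * (v - mean);
        double sd = Math.sqrt(sumSquares / (n - 1));

        double[] out = new double[n];
        for (int i = 0; i < n; i++) out[i] = (raw[i] - mean) / sd;
        return out;
    }

    public static void main(String[] args) {
        // Assumed raw input: the integers 1..20.
        double[] raw = new double[20];
        for (int i = 0; i < raw.length; i++) raw[i] = i + 1;
        System.out.println(Arrays.toString(zScore(raw)));
    }
}
```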


  • Yes, the cost function is convex.

    The derivative you're using comes from the squared-error cost function, which is convex, so it has no local minima other than the single global minimum. (In fact, this type of problem even admits a closed-form solution called the normal equation; it just becomes expensive for large problems, since it requires inverting or factoring the X^T X matrix, hence the use of gradient descent.)
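    To make the closed-form remark concrete: for a single feature with an intercept, the normal equation reduces to theta1 = Sxy / Sxx and theta0 = mean(y) - theta1 * mean(x), where Sxy and Sxx are the centered cross-product and sum of squares. A minimal sketch using the question's data; ClosedFormRegression and fit are illustrative names, not part of the original code.

```java
public class ClosedFormRegression {

    // Closed-form least-squares fit of y = theta0 + theta1 * x for one feature:
    // theta1 = Sxy / Sxx, theta0 = mean(y) - theta1 * mean(x).
    static double[] fit(double[] x, double[] y) {
        int n = x.length;
        double meanX = 0, meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        double sxy = 0, sxx = 0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        double theta1 = sxy / sxx;
        double theta0 = meanY - theta1 * meanX;
        return new double[] {theta0, theta1};
    }

    public static void main(String[] args) {
        double[] x1 = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double[] y = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double[] theta = fit(x1, y);
        System.out.println("theta 0 = " + theta[0]);
        System.out.println("theta 1 = " + theta[1]);
    }
}
```

    Because x1 has mean 0 after normalization, theta0 comes out as simply the mean of y, 0.4895.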

    The correct answer is around theta0 = 0.4895 and theta1 = 0.1652; this is easy to check in any statistical computing environment (see the bottom of this answer if you're skeptical).

    Below I point out the mistakes in your code. After fixing them, you'll get the answer above to within 4 decimal places.

    Problems with your implementation:

    You are right to expect it to converge to the global minimum, but there are problems in the implementation.

    First, each time you recompute derivative0 and derivative1, you forget to reset them to 0, so you were accumulating the derivatives across iterations of the do {} while () loop.

    You need this at the start of the do-while loop:

    do {                                                                    
       derivative0 = 0;                                                      
       derivative1 = 0;

    Next, look at these two lines:

    derivative0 = (derivative0 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m);
    derivative1 = (derivative1 + (theta0 + (theta1 * x1[i]) - y[i])) * (1/m) * x1[i];

    The (1/m) and x1[i] factors should multiply only the current term (theta0 + (theta1 * x1[i]) - y[i]). As written, they also rescale the running totals derivative0 and derivative1 on every pass through the for loop.

    Your version is a little hard to follow, so let's rewrite it more clearly, closer to the mathematical formula (1/m) * sum_i (y_hat_i - y_i) * x_i:
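    For reference, assuming the usual squared-error cost with a 1/(2m) factor, the gradient the loop should compute is below. The 1/m multiplies the whole sum, and x_i multiplies only the i-th residual.

```latex
\begin{aligned}
J(\theta_0, \theta_1) &= \frac{1}{2m} \sum_{i=1}^{m} \left( \theta_0 + \theta_1 x_i - y_i \right)^2 \\
\frac{\partial J}{\partial \theta_0} &= \frac{1}{m} \sum_{i=1}^{m} \left( \theta_0 + \theta_1 x_i - y_i \right) \\
\frac{\partial J}{\partial \theta_1} &= \frac{1}{m} \sum_{i=1}^{m} \left( \theta_0 + \theta_1 x_i - y_i \right) x_i
\end{aligned}
```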

    // Reset the sums on every gradient-descent iteration; don't accumulate them across iterations
    derivative0 = 0;
    derivative1 = 0;
    for (int i = 0; i < m; i++) {
        derivative0 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]);
        derivative1 += (1/m) * (theta0 + (theta1 * x1[i]) - y[i]) * x1[i];
    }

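    That should get you close. However, I find your learning rate alpha a bit big. With alpha = 1.5 each update overshoots the minimum, so the derivatives change sign from one iteration to the next; because the while condition tests derivative0 > 0.0001 rather than its absolute value, the loop exits as soon as both derivatives are negative, long before it has converged. A smaller step avoids the overshoot: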

    double alpha = 0.5;                                                 
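    Putting the pieces together, here is a minimal sketch of a complete corrected program (not the original gist, which isn't reproduced here). The class name OneVariableRegressionFixed, the gradientDescent helper and its maxIterations guard are my own additions for illustration; the stopping test also uses Math.abs(...) so that a negative derivative can't end the loop early.

```java
public class OneVariableRegressionFixed {

    // Batch gradient descent for the model y ~ theta0 + theta1 * x.
    // Stops when both partial derivatives are within `tolerance` of zero
    // (in absolute value) or after maxIterations updates.
    static double[] gradientDescent(double[] x1, double[] y, double theta0, double theta1,
                                    double alpha, double tolerance, int maxIterations) {
        double m = x1.length;
        double derivative0;
        double derivative1;
        int iterations = 0;
        do {
            // Fresh sums for every gradient-descent iteration.
            derivative0 = 0;
            derivative1 = 0;
            for (int i = 0; i < x1.length; i++) {
                double error = theta0 + theta1 * x1[i] - y[i]; // y_hat_i - y_i
                derivative0 += error / m;
                derivative1 += error * x1[i] / m;
            }
            // Both derivatives were computed from the old thetas: simultaneous update.
            theta0 -= alpha * derivative0;
            theta1 -= alpha * derivative1;
            iterations++;
        } while ((Math.abs(derivative0) > tolerance || Math.abs(derivative1) > tolerance)
                 && iterations < maxIterations);
        return new double[] {theta0, theta1};
    }

    public static void main(String[] args) {
        double[] x1 = {-1.605793084, -1.436762233, -1.267731382, -1.098700531, -0.92966968, -0.760638829, -0.591607978, -0.422577127, -0.253546276, -0.084515425, 0.084515425, 0.253546276, 0.422577127, 0.591607978, 0.760638829, 0.92966968, 1.098700531, 1.267731382, 1.436762233, 1.605793084};
        double[] y = {0.3, 0.2, 0.24, 0.33, 0.35, 0.28, 0.61, 0.38, 0.38, 0.42, 0.51, 0.6, 0.55, 0.56, 0.53, 0.61, 0.65, 0.68, 0.74, 0.87};
        double[] theta = gradientDescent(x1, y, 0.5, 0.5, 0.5, 0.0001, 100_000);
        System.out.println("theta 0 = " + theta[0]);
        System.out.println("theta 1 = " + theta[1]);
    }
}
```

    With the absolute-value test, alpha = 1.5 also converges on this data (the iterates oscillate around the minimum while shrinking toward it), whereas alpha = 0.5 approaches it monotonically.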

    Confirm the results

    Run it and compare the output with a statistics package.

    Here's a gist on GitHub of your .java file.

    ➜  ~ javac OneVariableRegression.java && java OneVariableRegression
    theta 0 = 0.48950064086914064
    theta 1 = 0.16520139788757973

    I compared it with R:

    > x
     [1] -1.60579308 -1.43676223 -1.26773138 -1.09870053 -0.92966968 -0.76063883
     [7] -0.59160798 -0.42257713 -0.25354628 -0.08451543  0.08451543  0.25354628
    [13]  0.42257713  0.59160798  0.76063883  0.92966968  1.09870053  1.26773138
    [19]  1.43676223  1.60579308
    > y
     [1] 0.30 0.20 0.24 0.33 0.35 0.28 0.61 0.38 0.38 0.42 0.51 0.60 0.55 0.56 0.53
    [16] 0.61 0.65 0.68 0.74 0.87
    > lm(y ~ x)
    lm(formula = y ~ x)
    (Intercept)            x  
         0.4895       0.1652  

    Now your code gives the correct answer to at least 4 decimal places.