I have implemented a very simple linear regression with a gradient descent algorithm in JavaScript, but after consulting multiple sources and trying several things, I cannot get it to converge.
The data is perfectly linear: the inputs are just the numbers 0 to 30, and the correct outputs to learn are x * 3.
This is the logic behind the gradient descent:
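```javascript
// One online-learning step: nudge the slope and intercept toward the target.
train(input, output) {
  const predictedOutput = this.predict(input);
  const delta = output - predictedOutput; // y - y^
  this.m += this.learningRate * delta * input;
  this.b += this.learningRate * delta;
}

// The model: y^ = m * x + b.
predict(x) {
  return x * this.m + this.b;
}
```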
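To try the `train` and `predict` methods above end to end, they can be wrapped in a class and run over the 0 to 30 data. This is a sketch under assumptions: the class name `LinearRegression`, the starting weights of 0, the learning rates and the epoch counts are illustrative choices, not the original code. It suggests that the update rule itself can converge on this data, and that the same rule blows up once the learning rate is too large for inputs as big as 30.

```javascript
// Hypothetical wrapper class around the question's train/predict methods.
class LinearRegression {
  constructor(learningRate) {
    this.learningRate = learningRate;
    this.m = 0; // slope, starting at 0 (assumption)
    this.b = 0; // intercept, starting at 0 (assumption)
  }

  // One online (per-sample) update, same rule as in the question.
  train(input, output) {
    const predictedOutput = this.predict(input);
    const delta = output - predictedOutput;
    this.m += this.learningRate * delta * input;
    this.b += this.learningRate * delta;
  }

  predict(x) {
    return x * this.m + this.b;
  }
}

// The data from the question: inputs 0..30, targets x * 3.
const inputs = Array.from({ length: 31 }, (_, i) => i);
const outputs = inputs.map((x) => x * 3);

// Runs `epochs` passes over the data, one train() call per sample.
function fit(model, epochs) {
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (let i = 0; i < inputs.length; i++) {
      model.train(inputs[i], outputs[i]);
    }
  }
  return model;
}

// After a step on sample x, that sample's error is multiplied by
// (1 - learningRate * (x * x + 1)). For x = 30 this factor drops below -1
// once learningRate exceeds about 2 / 901, and the updates start to blow up.
const stable = fit(new LinearRegression(0.0001), 100000);
const diverged = fit(new LinearRegression(0.01), 100);

console.log(stable.m, stable.b);
console.log(diverged.m, diverged.b);
```

The small learning rate needs many passes because the intercept is only weakly constrained by this data; scaling the inputs down is a common way to allow larger steps.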
I took the formulas from different places, including:
I have already tried:

- … (`y = x * 3`)
- … (`y = x * 3 + 2`)

Still, the weights (`this.b` and `this.m`) do not approach the correct values; they diverge to infinity.
I'm obviously doing something wrong, but I cannot figure out what it is.
Update: Here's a little bit more context that may help figure out what my problem is exactly:
I'm trying to model a simple approximation to a linear function, with online learning by a linear regression pseudo-neuron. With that, my parameters are:

- Weights: `[this.m, this.b]`
- Inputs: `[x, 1]`
- Activation function: `z(x) = x`
As such, my net will be expressed by `y = this.m * x + this.b * 1`, simulating the data-driven function that I want to approximate (`y = 3 * x`).
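The pseudo-neuron view amounts to a weighted sum of the inputs passed through an activation. A minimal sketch, where `neuronOutput` and its parameters are hypothetical names and the default activation is the identity `z(x) = x` mentioned above:

```javascript
// Output = activation(sum of weights[i] * inputs[i]).
// `neuronOutput` is a hypothetical helper for illustration.
function neuronOutput(weights, inputs, activation = (z) => z) {
  let sum = 0;
  for (let i = 0; i < weights.length; i++) {
    sum += weights[i] * inputs[i];
  }
  return activation(sum);
}

// Weights [m, b] and inputs [x, 1] with the identity activation
// reproduce y = m * x + b * 1.
const m = 3;
const b = 0;
const x = 7;
const y = neuronOutput([m, b], [x, 1]);
console.log(y);
```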
What I want is for my network to "learn" the parameters `this.m = 3` and `this.b = 0`, but I seem to get stuck in a local minimum.
My error function is the mean-squared error:
```javascript
error(allInputs, allOutputs) {
  let error = 0;
  for (let i = 0; i < allInputs.length; i++) {
    const x = allInputs[i];
    const y = allOutputs[i];
    const predictedOutput = this.predict(x);
    const delta = y - predictedOutput;
    error += delta * delta;
  }
  return error / allInputs.length;
}
```
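To check the `error` method against numbers that can be worked out by hand, here is a standalone sketch that takes the slope and intercept as arguments (an assumption for testability; the method above reads `this.m` and `this.b`). On the 0 to 30 data, the target line gives an error of 0, while the flat line y = 3 gives 9 times the mean of (x - 1)^2, which is 2484.

```javascript
// Standalone mean squared error of the line y = m * x + b over a data set.
// `meanSquaredError` is a hypothetical name for illustration.
function meanSquaredError(m, b, allInputs, allOutputs) {
  let error = 0;
  for (let i = 0; i < allInputs.length; i++) {
    const delta = allOutputs[i] - (m * allInputs[i] + b);
    error += delta * delta;
  }
  return error / allInputs.length;
}

const xs = Array.from({ length: 31 }, (_, i) => i);
const ys = xs.map((x) => x * 3);

console.log(meanSquaredError(3, 0, xs, ys)); // 0: the target line fits exactly
console.log(meanSquaredError(0, 3, xs, ys)); // 2484: the flat line y = 3
```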
My logic for updating my weights will be (according to the sources I've checked so far) `w_i -= alpha * dError/dw_i`.
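Applied to the mean squared error over all samples, that rule can be sketched as below. Note this is full-batch gradient descent, a swapped-in variant of the per-sample updates in `train`; `gradientStep`, the value of `alpha` and the step count are illustrative assumptions.

```javascript
// One step of w_i -= alpha * dError/dw_i, where Error is the mean squared
// error over all samples. `gradientStep` is a hypothetical helper.
function gradientStep(m, b, xs, ys, alpha) {
  let dm = 0; // dError/dm
  let db = 0; // dError/db
  for (let i = 0; i < xs.length; i++) {
    const delta = ys[i] - (m * xs[i] + b); // y - y^
    dm += (-2 * delta * xs[i]) / xs.length;
    db += (-2 * delta) / xs.length;
  }
  // Move against the gradient.
  return [m - alpha * dm, b - alpha * db];
}

const xs = Array.from({ length: 31 }, (_, i) => i);
const ys = xs.map((x) => x * 3);

let m = 0;
let b = 0;
// For this data, alpha above roughly 0.0033 makes the steps grow instead
// of shrink; 0.002 is an illustrative choice below that bound.
for (let step = 0; step < 20000; step++) {
  [m, b] = gradientStep(m, b, xs, ys, 0.002);
}
console.log(m, b);
```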
For the sake of simplicity, I'll call my weights `this.m` and `this.b`, so we can relate them back to my JavaScript code. I'll also call `y^` the predicted value.
From here:
error = y - y^
= y - this.m * x + this.b
dError/dm = -x
dError/db = 1
And so, applying that to the weight correction logic:
this.m += alpha * x
this.b -= alpha * 1
But this doesn't seem correct at all.
I finally found what's wrong, and I'm answering my own question in hopes it will help beginners in this area too.
First, as Sascha said, I had some theoretical misunderstandings. It may be correct that your adjustment includes the input value verbatim, but as he said, it should already be part of the gradient. This all depends on your choice of the error function.
Your error function is what you use to measure how far off you were from the real value, and that measurement needs to be consistent. I was using mean squared error as the measurement (as you can see in my `error` method), but inside the training method I was measuring the error as the plain difference (`y^ - y`). Your gradient depends on the choice of error function, so choose one and stick with it.
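For reference, here is the standard gradient of the squared error of a single sample (the term that the `error` method averages), written with $\hat{y}$ for the predicted value and $\alpha$ for the learning rate:

$$
\begin{aligned}
E &= (y - \hat{y})^2, \qquad \hat{y} = m\,x + b \\
\frac{\partial E}{\partial m} &= -2\,(y - \hat{y})\,x, \qquad \frac{\partial E}{\partial b} = -2\,(y - \hat{y}) \\
m &\leftarrow m - \alpha\,\frac{\partial E}{\partial m} = m + 2\alpha\,(y - \hat{y})\,x \\
b &\leftarrow b - \alpha\,\frac{\partial E}{\partial b} = b + 2\alpha\,(y - \hat{y})
\end{aligned}
$$

Writing `learningRate = 2 * alpha`, these have the same form as the two update lines in the question's `train` method (with `delta = y - y^`): the input `x` enters the update for `m` through the gradient itself.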
Second, simplify your assumptions in order to test what's wrong. In this case, I had a very good idea of what the function to approximate was (`y = x * 3`), so I manually set the weights (`this.b` and `this.m`) to the right values, and I still saw the error diverge. This means that weight initialization was not the problem in this case.
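That check can be made mechanical. A minimal sketch, using a hypothetical object literal as a stand-in for the question's class: with `m = 3` and `b = 0` already correct, every sample has `delta = 0`, so a correct update rule must leave the weights untouched and the error must stay at 0 for any learning rate. If the weights still drift, the problem is in the update or in the data being fed.

```javascript
// Hypothetical stand-in for the question's class, with the weights set by hand.
const model = {
  m: 3,
  b: 0,
  learningRate: 0.01,
  predict(x) {
    return x * this.m + this.b;
  },
  train(input, output) {
    const delta = output - this.predict(input);
    this.m += this.learningRate * delta * input;
    this.b += this.learningRate * delta;
  },
};

// Feed the correct data once and record the largest error seen.
const inputs = Array.from({ length: 31 }, (_, i) => i);
let maxAbsDelta = 0;
for (const x of inputs) {
  model.train(x, x * 3);
  maxAbsDelta = Math.max(maxAbsDelta, Math.abs(x * 3 - model.predict(x)));
}
console.log(model.m, model.b, maxAbsDelta);
```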
After searching some more, I found my error was somewhere else: the function that was feeding data into the network was mistakenly passing a hardcoded `3` as the output value (it was using the wrong index in an array). So the oscillation I saw was the network trying to approximate `y = 0 * x + 3` (`this.b = 3` and `this.m = 0`), but because of the small learning rate and the error in the derivative of the error function, `this.b` wasn't going to get near the right value, which made `this.m` make wild jumps to adjust to it.
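The real feeding code isn't shown, so the following is only a hypothetical reconstruction of this kind of bug: samples are assumed to be stored as `[input, output]` pairs, and the buggy feeder reads the output with a fixed index (`samples[1][1]`, which happens to be 3) instead of `samples[i][1]`. With the question's update rule and a small learning rate (both assumptions here), the buggy feed steers the weights toward `y = 0 * x + 3` instead of `y = 3 * x`.

```javascript
// Hypothetical data layout: each sample is [input, output].
const samples = Array.from({ length: 31 }, (_, x) => [x, x * 3]);

// Same update rule as the question's train/predict, as a plain object.
function makeModel(learningRate) {
  return {
    m: 0,
    b: 0,
    learningRate,
    predict(x) {
      return x * this.m + this.b;
    },
    train(input, output) {
      const delta = output - this.predict(input);
      this.m += this.learningRate * delta * input;
      this.b += this.learningRate * delta;
    },
  };
}

// Feeds every sample once per epoch. With `buggy` set, the output is read
// with a fixed index (always samples[1][1] === 3) instead of samples[i][1].
function feed(model, epochs, buggy) {
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (let i = 0; i < samples.length; i++) {
      const input = samples[i][0];
      const output = buggy ? samples[1][1] : samples[i][1];
      model.train(input, output);
    }
  }
  return model;
}

const good = feed(makeModel(0.0001), 100000, false);
const bad = feed(makeModel(0.0001), 100000, true);
console.log(good.m, good.b);
console.log(bad.m, bad.b);
```

With the correct feed the weights head toward `m = 3`, `b = 0`; with the buggy one they head toward `m = 0`, `b = 3`, which fits the constant target perfectly and so looks like "learning" to the network.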
Finally, keep track of the error measurement as your network trains, so you have some insight into what's going on. This helps a lot in telling apart simple overfitting, learning rates that are too large, and plain simple mistakes.
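A minimal sketch of that kind of tracking, recording the mean squared error after every epoch. The function name `trainWithHistory`, the learning rates and the epoch counts are illustrative assumptions; the update rule is the one from the question.

```javascript
// Trains y = m * x + b on the 0..30 data with per-sample updates and returns
// the mean squared error measured after each epoch.
function trainWithHistory(learningRate, epochs) {
  const xs = Array.from({ length: 31 }, (_, i) => i);
  const ys = xs.map((x) => x * 3);
  let m = 0;
  let b = 0;
  const history = [];
  for (let epoch = 0; epoch < epochs; epoch++) {
    for (let i = 0; i < xs.length; i++) {
      const delta = ys[i] - (m * xs[i] + b);
      m += learningRate * delta * xs[i];
      b += learningRate * delta;
    }
    // Mean squared error over the whole data set after this epoch.
    let error = 0;
    for (let i = 0; i < xs.length; i++) {
      const delta = ys[i] - (m * xs[i] + b);
      error += delta * delta;
    }
    history.push(error / xs.length);
  }
  return history;
}

const small = trainWithHistory(0.0001, 2000);
const large = trainWithHistory(0.01, 10);
// A shrinking error points to a working setup; an error that grows from
// epoch to epoch (or turns into Infinity/NaN) points to a learning rate that
// is too large or a bug in the update or the data.
console.log(small[0], small[small.length - 1]);
console.log(large);
```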