I have written code for a neural network, but when I train it, it does not produce the desired output (the network is not learning, and I sometimes get NaN values during training). What is wrong with my backpropagation algorithm? Attached below is how I derived the formulas for the weight and bias gradients, respectively. Full code can be found here.
```java
public double[][] predict(double[][] input) {
    if(input.length != this.activations.get(0).length || input[0].length != this.activations.get(0)[0].length) {
        throw new IllegalArgumentException("Prediction Error!");
    }
    this.activations.set(0, input);
    for(int i = 1; i < this.activations.size(); i++) {
        this.activations.set(i, this.sigmoid(this.add(this.multiply(this.weights.get(i-1), this.activations.get(i-1)), this.biases.get(i-1))));
    }
    return this.activations.get(this.n-1);
}

public void train(double[][] input, double[][] target) {
    //calculate activations
    this.predict(input);
    //calculate weight gradients
    for(int l = 0; l < this.weightGradients.size(); l++) {
        for(int i = 0; i < this.weightGradients.get(l).length; i++) {
            for(int j = 0; j < this.weightGradients.get(l)[0].length; j++) {
                this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
            }
        }
    }
    //calculated bias gradients
    for(int l = 0; l < this.biasGradients.size(); l++) {
        for(int i = 0; i < this.biasGradients.get(l).length; i++) {
            for(int j = 0; j < this.biasGradients.get(l)[0].length; j++) {
                this.biasGradients.get(l)[i][j] = this.gradientOfBias(l, i, j, target);
            }
        }
    }
    //apply gradient
    for(int i = 0; i < this.weights.size(); i++) {
        this.weights.set(i, this.subtract(this.weights.get(i), this.weightGradients.get(i)));
    }
    for(int i = 0; i < this.biases.size(); i++) {
        this.biases.set(i, this.subtract(this.biases.get(i), this.biasGradients.get(i)));
    }
}

private double gradientOfWeight(int l, int i, int j, double[][] t) { //when referring to A, use l+1 because A[0] is input vector, n-1 because n starts at 1
    double z = (this.activations.get(l + 1)[i][0] * (1.0 - this.activations.get(l + 1)[i][0]) * this.activations.get(l)[j][0]);
    if((l + 1) < (this.n - 1)) {
        double sum = 0.0;
        for(int k = 0; k < this.weights.get(l + 1).length; k++) {
            sum += this.gradientOfWeight(l + 1, k, i, t)*this.weights.get(l + 1)[k][i];
        }
        return ((z * sum) / this.activations.get(l + 1)[i][0]);
    } else if((l + 1) == (this.n - 1)) {
        return 2.0 * (this.activations.get(l + 1)[i][0] - t[i][0]) * z;
    }
    throw new IllegalArgumentException("Weight Gradient Calculation Error!");
}
```
The amount of math involved in this question, combined with the lack of data or a reproducible example of your code, makes it nearly impossible to answer the original question of "where is my NaN coming from?"
Instead, I would propose reframing it as a simpler question: "How can I tell where a value like NaN is coming from in my code?"
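Before reaching for the debugger, a cheap first step toward that question is to scan your matrices after every call to `train` and report the first non-finite entry, which tells you which layer went bad first. The sketch below assumes your weights, biases and activations are each stored as a `List<double[][]>`, as the question's code suggests; the class and method names (`NanScan`, `findNonFinite`) are hypothetical.

```java
import java.util.List;

public class NanScan {
    /** Returns "name[k][i][j] = value" for the first NaN or infinite entry, or null if all are finite. */
    static String findNonFinite(String name, List<double[][]> matrices) {
        for (int k = 0; k < matrices.size(); k++) {
            double[][] m = matrices.get(k);
            for (int i = 0; i < m.length; i++) {
                for (int j = 0; j < m[i].length; j++) {
                    double v = m[i][j];
                    // Double.isFinite is false for NaN, +Infinity and -Infinity (Java 8+)
                    if (!Double.isFinite(v)) {
                        return name + "[" + k + "][" + i + "][" + j + "] = " + v;
                    }
                }
            }
        }
        return null; // every entry is an ordinary finite number
    }

    public static void main(String[] args) {
        List<double[][]> weights = List.of(
                new double[][] {{0.1, 0.2}, {0.3, 0.4}},
                new double[][] {{0.5, Double.NaN}});
        System.out.println(findNonFinite("weights", weights)); // prints "weights[1][0][1] = NaN"
    }
}
```

In your code you might call `findNonFinite("weights", this.weights)` (and the same for `this.biases` and `this.activations`) at the end of `train` and stop as soon as it returns something other than `null`.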
If you can run your code in an IDE, most IDEs support conditional breakpoints, i.e. breakpoints that pause your code only when a condition holds, such as a variable reaching a particular value. In your case, I would recommend running your code in your preferred IDE with a conditional breakpoint that detects when a value is NaN.
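When writing that breakpoint condition, keep in mind that NaN compares unequal to everything, including itself, so a condition like `z == Double.NaN` never fires. Use `Double.isNaN(z)` (or the idiom `z != z`) instead. This small standalone program, with an illustrative class name, shows the difference:

```java
public class NanCondition {
    public static void main(String[] args) {
        double z = 0.0 / 0.0; // produces NaN

        // Always false: NaN is unequal to every value, including NaN itself
        System.out.println(z == Double.NaN); // prints false

        // Correct ways to detect NaN; either works as a breakpoint condition
        System.out.println(Double.isNaN(z)); // prints true
        System.out.println(z != z);          // prints true
    }
}
```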
You can read more about how to set one up in this SO thread, where the topic of checking a double for NaN is also discussed: Eclipse Debugger doesn't stop at conditional breakpoint
Another follow-up consideration is WHERE you need to put these breakpoints. The short answer is: wherever a double is computed, because any of those computations might introduce the NaN.
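To illustrate how an innocent-looking double computation can introduce a NaN, the sketch below lists the operations that produce one in Java, then replays the shape of the expression `(z * sum) / this.activations.get(l + 1)[i][0]` from `gradientOfWeight` with a saturated activation. It assumes your `sigmoid` is the usual `1 / (1 + Math.exp(-x))` (the question does not show it), and the input `-800` and the other literals are made up for the demonstration.

```java
public class NanSources {
    // Assumed sigmoid; the question's own sigmoid(...) is not shown
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    public static void main(String[] args) {
        // Operations that yield NaN for doubles in Java
        System.out.println(0.0 / 0.0);                                           // NaN
        System.out.println(Double.POSITIVE_INFINITY - Double.POSITIVE_INFINITY); // NaN
        System.out.println(0.0 * Double.POSITIVE_INFINITY);                      // NaN
        System.out.println(Math.sqrt(-1.0));                                     // NaN
        System.out.println(Math.log(-1.0));                                      // NaN

        // A large negative pre-activation saturates the sigmoid to exactly 0.0,
        // because Math.exp(800) overflows to Infinity and 1 / Infinity == 0.0
        double a = sigmoid(-800.0);
        System.out.println(a);                                                   // 0.0

        // Same shape as the recursive branch of gradientOfWeight:
        // z = a * (1 - a) * previousActivation, then (z * sum) / a
        double previousActivation = 0.5; // made-up value
        double sum = 1.0;                // made-up value
        double z = a * (1.0 - a) * previousActivation;
        System.out.println((z * sum) / a);                                       // NaN (0.0 / 0.0)
    }
}
```

Whether this particular path is what happens in your network depends on your data and weights, which is exactly what a breakpoint on `z` and `sum` would tell you.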
To that effect, I make the following two recommendations:
First, put a breakpoint wherever you currently compute doubles, to see whether NaNs come from these computations. That would be these two variables:
```java
double z = ...
double sum = ...
```
Second, refactor your calls to gradientOfWeight so that each result is stored in a temporary variable, and then put a similar breakpoint on THOSE interim computations.
So instead of:
```java
this.weightGradients.get(l)[i][j] = this.gradientOfWeight(l, i, j, target);
```
You would have:
```java
double interimComputationToListenForNaNOn = this.gradientOfWeight(l, i, j, target);
this.weightGradients.get(l)[i][j] = interimComputationToListenForNaNOn;
```
These interim variables are mostly a convenience: they give you an easy way to monitor the computation without changing the call in any significant way. There may be a smarter way to do this without an interim variable, but this one seems the easiest to monitor and explain.
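If you would rather check in code than with a breakpoint, one possible alternative (a different technique from the debugger approach above) is a small pass-through helper that returns its argument unchanged but throws as soon as it sees a NaN or infinite value. The call site stays a one-liner, and the exception's stack trace points at the first computation that went bad. The names `FailFast`, `checked` and the label strings are hypothetical.

```java
public class FailFast {
    /** Returns value unchanged, or throws if it is NaN or infinite. */
    static double checked(String label, double value) {
        if (!Double.isFinite(value)) {
            throw new ArithmeticException(label + " is " + value);
        }
        return value;
    }

    public static void main(String[] args) {
        double[][] gradients = new double[1][2];

        // Intended usage inside train(), for example:
        // this.weightGradients.get(l)[i][j] =
        //         checked("dW[" + l + "][" + i + "][" + j + "]", this.gradientOfWeight(l, i, j, target));
        gradients[0][0] = checked("dW[0][0]", 0.25); // finite, passes through unchanged
        try {
            gradients[0][1] = checked("dW[0][1]", 0.0 / 0.0);
        } catch (ArithmeticException e) {
            System.out.println(e.getMessage()); // prints "dW[0][1] is NaN"
        }
    }
}
```

Most IDEs can also break whenever a given exception type is thrown, so combining this helper with an exception breakpoint on `ArithmeticException` drops you into the debugger at the exact call that produced the bad value. Note that building the label string on every call has a small cost, so you may want to remove the checks once the bug is found.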