Tags: neural-network, regression, supervised-learning

ANN: Approximating non-linear function with neural network


I am learning to build neural networks for regression problems, and they work well for approximating linear functions. A 1–5–1 setup with linear activation functions in the hidden and output layers does the trick, and the results are fast and reliable. However, when I try to feed it simple quadratic data (f(x) = x*x), here is what happens:

With a linear activation function, it tries to fit a straight line through the dataset.


And with the tanh activation function, it tries to fit a tanh curve through the dataset.


This makes me believe that the current setup is inherently unable to learn anything but a linear relation, since it just reproduces the shape of the activation function on the chart. But this may not be true, because I've seen other implementations learn curves perfectly well, so I may be doing something wrong. Please provide your guidance.


About my code

My weights are initialized randomly in (-1, 1), and the inputs are not normalized. The dataset is fed in random order. Changing the learning rate or adding layers does not change the picture much.
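
For reference, the initialization and ordering described above amount to something like this (an illustration, not the fiddle's exact code):

// Weights drawn uniformly from (-1, 1).
function randomWeight() {
    return Math.random() * 2 - 1;
}

// Fisher-Yates shuffle, so the dataset is fed in random order.
function shuffle(samples) {
    for (var i = samples.length - 1; i > 0; i--) {
        var j = Math.floor(Math.random() * (i + 1));
        var tmp = samples[i];
        samples[i] = samples[j];
        samples[j] = tmp;
    }
    return samples;
}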

I've created a jsfiddle; the place to play with is this function:

// Produces one training sample: [input vector, target vector].
// As written, the target equals the input, i.e. f(x) = x.
function trainingSample(n) {
    return [[n], [n]];
}

It produces a single training sample: an array containing an input vector and a target vector. In this example it produces the identity function f(x) = x. Modify it to [[n], [n*n]] and you've got a quadratic function.
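
For example, the quadratic variant, with an optional scaling step I've added (the original inputs are not normalized; the divisor of 10 is an assumed sample range, adjust it to match how the fiddle actually draws n):

// Quadratic variant. The scaling line is an addition on my part and
// assumes n is drawn from roughly [0, 10].
function trainingSample(n) {
    var x = n / 10;          // keep inputs in a small range (optional)
    return [[x], [x * x]];   // target: f(x) = x^2
}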

The play button is in the upper right, and there are also two input boxes to enter these values manually. If the target (right) box is left empty, you can test the output of the network by feedforward only.

There is also a configuration object for the network in the code, where you can set the learning rate and other things (search for var Config).


Solution

  • It occurred to me that in the setup I am describing, it is impossible to learn non-linear functions because of the choice of features. Nowhere in the forward pass is there a dependency on the input of power higher than 1, which is why I am seeing a snapshot of my activation function in the output. Duh.
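
To make this concrete, here is a scalar sketch of the linear case (my own illustration, not the fiddle's code). With identity activations, two stacked layers collapse into a single affine map, so the network can only ever produce a straight line, whatever f(x) = x*x data it sees:

// One hidden "layer" with identity activation, then a linear output.
function linearNet(x, w1, b1, w2, b2) {
    var h = w1 * x + b1;    // hidden layer, no non-linearity
    return w2 * h + b2;     // output layer, also linear
}

// Algebraically the same single affine map:
function collapsed(x, w1, b1, w2, b2) {
    return (w2 * w1) * x + (w2 * b1 + b2);
}

// linearNet(3, 0.5, 0.1, -2, 0.3) and collapsed(3, 0.5, 0.1, -2, 0.3)
// both give -2.9 (up to floating-point rounding).

One fix that follows directly from the "choice of features" diagnosis is to feed x*x in as a second input feature; then even a purely linear network can represent the target exactly (weight 1 on the squared feature, 0 on the rest). A hypothetical variant of the sample generator:

// Hypothetical feature-engineered sample: inputs [n, n*n], target n*n.
function trainingSample(n) {
    return [[n, n * n], [n * n]];
}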