My goal is to create a neural network with a single hidden layer (with ReLU activation) that can approximate a simple univariate square root function. I have implemented the network with numpy and also did a gradient check. Everything seems to be fine except for the result: for some reason I can only obtain linear approximations, like this: [image: noisy sqrt approx]
I tried changing the hyperparameters, without any success. Any ideas?
import numpy as np
step_size = 1e-6
input_size, output_size = 1, 1
h_size = 10
train_size = 500
x_train = np.abs(np.random.randn(train_size, 1) * 1000)
y_train = np.sqrt(x_train) + np.random.randn(train_size, 1) * 0.5
#initialize weights and biases
Wxh = np.random.randn(input_size, h_size) * 0.01
bh = np.zeros((1, h_size))
Why = np.random.randn(h_size, output_size) * 0.01
by = np.zeros((1, output_size))
for i in range(300000):
    # forward pass
    h = np.maximum(0, np.dot(x_train, Wxh) + bh)
    y_est = np.dot(h, Why) + by
    loss = np.sum((y_est - y_train)**2) / train_size
    dy = 2 * (y_est - y_train) / train_size
    print("loss: ", loss)
    # backprop at output
    dWhy = np.dot(h.T, dy)
    dby = np.sum(dy, axis=0, keepdims=True)
    dh = np.dot(dy, Why.T)
    # backprop ReLU non-linearity
    dh[h <= 0] = 0
    # backprop Wxh, and bh
    dWxh = np.dot(x_train.T, dh)
    dbh = np.sum(dh, axis=0, keepdims=True)
    # gradient descent update
    Wxh += -step_size * dWxh
    bh += -step_size * dbh
    Why += -step_size * dWhy
    by += -step_size * dby
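For reference, a gradient check of the kind mentioned above can be sketched roughly as follows. This is only an illustration, not the check that was actually run: the helper names `loss_and_grads` and `max_relative_error`, the small toy data, and the tolerance values are all assumptions. It compares the analytic gradients of the same forward/backward pass against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small toy problem (assumed sizes, chosen only to keep the check fast).
x = np.abs(rng.standard_normal((20, 1)))
y = np.sqrt(x)

params = {
    "Wxh": rng.standard_normal((1, 5)) * 0.5,
    "bh": rng.standard_normal((1, 5)) * 0.5,
    "Why": rng.standard_normal((5, 1)) * 0.5,
    "by": np.zeros((1, 1)),
}

def loss_and_grads(p):
    # Same forward and backward pass as in the training loop above.
    h = np.maximum(0, np.dot(x, p["Wxh"]) + p["bh"])
    y_est = np.dot(h, p["Why"]) + p["by"]
    loss = np.sum((y_est - y) ** 2) / len(x)
    dy = 2 * (y_est - y) / len(x)
    dh = np.dot(dy, p["Why"].T)
    dh[h <= 0] = 0
    grads = {
        "Wxh": np.dot(x.T, dh),
        "bh": np.sum(dh, axis=0, keepdims=True),
        "Why": np.dot(h.T, dy),
        "by": np.sum(dy, axis=0, keepdims=True),
    }
    return loss, grads

def max_relative_error(p, eps=1e-5):
    # Perturb each parameter entry in turn and compare the central
    # finite difference of the loss with the analytic gradient.
    _, grads = loss_and_grads(p)
    worst = 0.0
    for name, value in p.items():
        for idx in np.ndindex(value.shape):
            old = value[idx]
            value[idx] = old + eps
            loss_plus, _ = loss_and_grads(p)
            value[idx] = old - eps
            loss_minus, _ = loss_and_grads(p)
            value[idx] = old  # restore the original entry
            numeric = (loss_plus - loss_minus) / (2 * eps)
            analytic = grads[name][idx]
            denom = max(abs(numeric) + abs(analytic), 1e-8)
            worst = max(worst, abs(numeric - analytic) / denom)
    return worst

worst_error = max_relative_error(params)
print("max relative error:", worst_error)
```

A small maximum relative error suggests the backprop formulas are consistent with the loss, which is why a passing gradient check does not rule out problems with data scaling.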
Edit: It seems the problem was the lack of normalization and the data not being zero-centered. After applying these transformations to the training data, I managed to obtain the following result: [image: noisy sqrt2]
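One plausible reading of why zero-centering matters here: with zero biases and strictly positive inputs, each ReLU unit is either active for every training point (positive weight) or inactive for every training point (negative weight), so the network output is exactly proportional to x. With inputs on the order of 1000 and a step size of 1e-6, the biases are unlikely to move far enough to place a kink inside the data range. The sketch below illustrates the first part only; `predict` and the evenly spaced test grids are assumptions made for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)
h_size = 10

# Same initialisation scheme as in the question: small weights, zero biases.
Wxh = rng.standard_normal((1, h_size)) * 0.01
bh = np.zeros((1, h_size))
Why = rng.standard_normal((h_size, 1)) * 0.01
by = np.zeros((1, 1))

def predict(x):
    # Forward pass of the single-hidden-layer ReLU network.
    h = np.maximum(0, np.dot(x, Wxh) + bh)
    return np.dot(h, Why) + by

# Strictly positive inputs, similar in scale to np.abs(randn) * 1000.
x_pos = np.linspace(1.0, 3000.0, 50)[:, None]
y_pos = predict(x_pos)

# If the output is proportional to x, then y / x is the same for every point.
slopes_pos = y_pos / x_pos
is_linear_on_positive = np.allclose(slopes_pos, slopes_pos[0])

# Zero-centred inputs straddle the ReLU kink at 0, so the slope changes there.
x_centred = x_pos - x_pos.mean()
y_centred = predict(x_centred)
second_diff = np.diff(y_centred[:, 0], n=2)
has_kink = not np.allclose(second_diff, 0)

print("linear on positive inputs:", is_linear_on_positive)
print("kink with centred inputs:", has_kink)
```

Once the inputs are centred, different hidden units switch on at different sides of zero, and training can then shift the biases to spread the kinks across the data range.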
I can get your code to produce a sort of piecewise linear approximation if I zero-centre and normalise your input and output ranges:
# normalise range and domain
x_train -= x_train.mean()
x_train /= x_train.std()
y_train -= y_train.mean()
y_train /= y_train.std()
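Because the network is then trained on standardised values, its predictions come out in standardised units as well. A minimal sketch, assuming you want predictions back on the original scale: store the means and standard deviations before transforming, and invert the transform afterwards. The names `x_raw`, `y_raw`, `to_network_input` and `to_original_units` are hypothetical and not part of the original code.

```python
import numpy as np

rng = np.random.default_rng(0)
# Data generated the same way as in the question.
x_raw = np.abs(rng.standard_normal((500, 1)) * 1000)
y_raw = np.sqrt(x_raw) + rng.standard_normal((500, 1)) * 0.5

# Store the statistics before transforming, so the transform can be undone.
x_mean, x_std = x_raw.mean(), x_raw.std()
y_mean, y_std = y_raw.mean(), y_raw.std()

x_train = (x_raw - x_mean) / x_std
y_train = (y_raw - y_mean) / y_std

def to_network_input(x):
    # Apply the training-set scaling to new inputs given in original units.
    return (x - x_mean) / x_std

def to_original_units(y_norm):
    # Map standardised network outputs back to the original y scale.
    return y_norm * y_std + y_mean

# Round trip: undoing the scaling recovers the original targets.
y_roundtrip = to_original_units(y_train)
print(np.allclose(y_roundtrip, y_raw))
```

New inputs would go through `to_network_input` before the forward pass, and the network output through `to_original_units` afterwards.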
The plot is produced like so, using the trained weights:
x = np.linspace(x_train.min(),x_train.max(),3000)
y = np.dot(np.maximum(0, np.dot(x[:,None], Wxh) + bh), Why) + by
import matplotlib.pyplot as plt
plt.plot(x,y)
plt.show()