I tried to build a neural network from scratch using the JAX NumPy module. During training, the model's accuracy doesn't seem to improve at all. Here is the code:
import jax
import jax.numpy as jnp

def init_params():
    key = jax.random.PRNGKey(0)
    W1 = jax.random.uniform(key, (10, 784), minval=-0.5, maxval=0.5)
    b1 = jax.random.uniform(key, (10, 1), minval=-0.5, maxval=0.5)
    W2 = jax.random.uniform(key, (10, 10), minval=-0.5, maxval=0.5)
    b2 = jax.random.uniform(key, (10, 1), minval=-0.5, maxval=0.5)
    return W1, b1, W2, b2

def ReLU(Z):
    return jnp.maximum(Z, 0)

def softmax(Z):
    A = jnp.exp(Z) / jnp.sum(jnp.exp(Z))
    return A

def forward_prop(W1, b1, W2, b2, X):
    Z1 = jnp.dot(W1, X) + b1
    A1 = ReLU(Z1)
    Z2 = jnp.dot(W2, A1) + b2
    A2 = softmax(Z2)
    return Z1, A1, Z2, A2

def ReLU_deriv(Z):
    return Z > 0

def one_hot(Y):
    one_hot_Y = jnp.zeros((Y.size, Y.max() + 1))
    one_hot_Y = one_hot_Y.at[jnp.arange(Y.size), Y].set(1)
    return one_hot_Y.T
def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):
    one_hot_Y = one_hot(Y)
    dZ2 = A2 - one_hot_Y
    # m is the number of training examples (X has shape (784, m));
    # it is defined globally elsewhere in my script
    dW2 = 1 / m * jnp.dot(dZ2, A1.T)
    db2 = 1 / m * jnp.sum(dZ2, axis=1, keepdims=True)
    dZ1 = jnp.dot(W2.T, dZ2) * ReLU_deriv(Z1)
    dW1 = 1 / m * jnp.dot(dZ1, X.T)
    db1 = 1 / m * jnp.sum(dZ1, axis=1, keepdims=True)
    return dW1, db1, dW2, db2
def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):
    W1 = W1 - alpha * dW1
    b1 = b1 - alpha * db1
    W2 = W2 - alpha * dW2
    b2 = b2 - alpha * db2
    return W1, b1, W2, b2

def get_predictions(A2):
    return jnp.argmax(A2, axis=0)

def get_accuracy(predictions, Y):
    print(predictions, Y)
    return jnp.sum(predictions == Y) / Y.size

def gradient_descent(X, Y, alpha, iterations):
    W1, b1, W2, b2 = init_params()
    for i in range(iterations):
        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)
        dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)
        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)
        if i % 10 == 0:
            print("Iteration: ", i)
            predictions = get_predictions(A2)
            print(get_accuracy(predictions, Y))
    return W1, b1, W2, b2

W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 1e-3, 500)
I'm a complete newbie to JAX and I can't figure this out. What is the issue, and how do I solve it?
I tried changing the weight-update function, since it seemed there was a problem with how the model weights were being updated.
In your softmax function, you need to specify the axis parameter in the jnp.sum call so that it sums along the correct axis. Here's how:
def softmax(Z):
    A = jnp.exp(Z) / jnp.sum(jnp.exp(Z), axis=0)
    return A
By specifying axis=0, you ensure that the exponentials are summed down each column, i.e. per training example, so each column of A2 is a valid probability distribution. Without it, jnp.sum reduces over the entire matrix, so the whole batch is normalized to 1 instead of each example.
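As a quick sanity check, here is a minimal sketch of the difference; the shapes and the PRNG key are made up for illustration, matching the (classes, batch) layout your forward_prop produces:

import jax
import jax.numpy as jnp

def softmax(Z):
    # column-wise softmax: one distribution per example
    return jnp.exp(Z) / jnp.sum(jnp.exp(Z), axis=0)

# Toy logits: 10 classes (rows) x 3 examples (columns).
key = jax.random.PRNGKey(42)
Z = jax.random.normal(key, (10, 3))

print(jnp.sum(softmax(Z), axis=0))  # each column sums to 1

# The original version reduces over the whole matrix, so the entire
# batch sums to 1 and every probability shrinks as the batch grows,
# which distorts dZ2 = A2 - one_hot_Y in backward_prop.
A_bad = jnp.exp(Z) / jnp.sum(jnp.exp(Z))
print(jnp.sum(A_bad))  # 1.0 over the whole batch, not per example

As a side note, it's also common to subtract jnp.max(Z, axis=0, keepdims=True) from Z before exponentiating to avoid overflow for large logits, but that's separate from the accuracy problem here.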