Search code examples
pythonnumpydeep-learningneural-networkjax

Building a Neural Network using JAX


I tried to build a Neural network from scratch using JAX numpy moldule. In the training phase it seems that the accuracy of the model doesn't improve at all. Here is the code.

import jax
import jax.numpy as jnp

def init_params():
    key = jax.random.PRNGKey(0)
    W1 = jax.random.uniform(key, (10, 784), minval=-0.5, maxval=0.5)
    b1 = jax.random.uniform(key, (10, 1), minval=-0.5, maxval=0.5)
    W2 = jax.random.uniform(key, (10, 10), minval=-0.5, maxval=0.5)
    b2 = jax.random.uniform(key, (10, 1), minval=-0.5, maxval=0.5)
    return W1, b1, W2, b2

def ReLU(Z):
    return jnp.maximum(Z, 0)

def softmax(Z):
    A = jnp.exp(Z) / jnp.sum(jnp.exp(Z))
    return A

def forward_prop(W1, b1, W2, b2, X):
    Z1 = jnp.dot(W1, X) + b1
    A1 = ReLU(Z1)
    Z2 = jnp.dot(W2, A1) + b2
    A2 = softmax(Z2)
    return Z1, A1, Z2, A2

def ReLU_deriv(Z):
    return Z > 0

def one_hot(Y):
    one_hot_Y = jnp.zeros((Y.size, Y.max() + 1))
    one_hot_Y = one_hot_Y.at[jnp.arange(Y.size), Y].set(1)
    return one_hot_Y.T


def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):
    one_hot_Y = one_hot(Y)
    dZ2 = A2 - one_hot_Y
    dW2 = 1 / m * jnp.dot(dZ2, A1.T)
    db2 = 1 / m * jnp.sum(dZ2, axis=1, keepdims=True)
    dZ1 = jnp.dot(W2.T, dZ2) * ReLU_deriv(Z1)
    dW1 = 1 / m * jnp.dot(dZ1, X.T)
    db1 = 1 / m * jnp.sum(dZ1, axis=1, keepdims=True)
    return dW1, db1, dW2, db2

def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):
    W1 = W1 - alpha * dW1
    b1 = b1 - alpha * db1    
    W2 = W2 - alpha * dW2  
    b2 = b2 - alpha * db2    
    return W1, b1, W2, b2


def get_predictions(A2):
    return jnp.argmax(A2, 0)

def get_accuracy(predictions, Y):
    print(predictions, Y)
    return jnp.sum(predictions == Y) / Y.size

def gradient_descent(X, Y, alpha, iterations):
    W1, b1, W2, b2 = init_params()
    for i in range(iterations):
        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)
        dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)
        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)
        if i % 10 == 0:
            print("Iteration: ", i)
            predictions = get_predictions(A2)
            print(get_accuracy(predictions, Y))
    return W1, b1, W2, b2


W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 1e-3, 500)

I'm a complete newbie to JAX and I cannot solve this. What is the issue and how to solve this ?

I tried to change the weight updating function as it seems to have some problem with updating the model weights


Solution

  • In your softmax function, you need to specify the axis parameter in the jnp.sum function to sum along the correct axis. Here's how:

    def softmax(Z):
        A = jnp.exp(Z) / jnp.sum(jnp.exp(Z), axis=0)
        return A
    

    By specifying axis=0, you ensure that the exponential values are summed across the correct axis.