Search code examples
pytorchjax

pytorch and jax networks give different accuracy with same settings


I have pytorch code which performs with more than 95% accuracy. The code essentially implements a feedforward neural network using PyTorch to classify the digits dataset. It trains the model using the Adam optimizer and computes the cross-entropy loss, and then evaluates the model's performance on the test set by calculating the accuracy.

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the digits dataset
digits = load_digits()

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=42
)

# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Convert the data to PyTorch tensors
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)

# Define the FFN model
class FFN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(FFN, self).__init__()
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_sizes)):
            if i == 0:
                self.hidden_layers.append(nn.Linear(input_size, hidden_sizes[i]))
            else:
                self.hidden_layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            self.hidden_layers.append(nn.ReLU())
        self.output_layer = nn.Linear(hidden_sizes[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

# Define the training parameters
input_size = X_train.shape[1]
hidden_sizes = [64, 32]  # Modify the hidden layer sizes as per your requirement
output_size = len(torch.unique(y_train_tensor))
learning_rate = 0.001
num_epochs = 200
batch_size = len(X_train)  # Set batch size to the size of the training dataset

# Create the FFN model
model = FFN(input_size, hidden_sizes, output_size)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Evaluate the model on the test set
with torch.no_grad():
    model.eval()
    outputs = model(X_test_tensor)
    _, predicted = torch.max(outputs.data, 1)
    for j in range(len(predicted)):
        print(predicted[j], y_test_tensor[j])
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0) * 100
    print(f"Test Accuracy: {accuracy:.2f}%")

Also I have the equivalent jax code, with performs with less than 10% of accuracy

import jax
import jax.numpy as jnp
from jax import grad, jit, random, value_and_grad
from jax.scipy.special import logsumexp
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from jax.example_libraries.optimizers import adam, momentum, sgd, nesterov, adagrad, rmsprop
from jax import nn as jnn


# Load the digits dataset
digits = load_digits()

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)

# Reshape the target variables
y_train_reshaped = jnp.reshape(y_train, (-1, 1))
y_test_reshaped = jnp.reshape(y_test, (-1, 1))

X_train_reshaped = jnp.reshape(X_train, (-1, 1))
X_test_reshaped = jnp.reshape(X_test, (-1, 1))
#print(np.shape(X_train),np.shape(y_train_reshaped),np.shape(y_train))

# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_reshaped)
y_test_scaled = scaler.transform(y_test_reshaped)

# Convert the data to JAX arrays
X_train_array = jnp.array(X_train, dtype=jnp.float32)
y_train_array = jnp.array(y_train_reshaped, dtype=jnp.int32)
X_test_array = jnp.array(X_test, dtype=jnp.float32)
y_test_array = jnp.array(y_test_reshaped, dtype=jnp.int32)

# Define the FFN model
def init_params(rng_key):
    sizes = [X_train_array.shape[1]] + hidden_sizes + [output_size]
    keys = random.split(rng_key, len(sizes))
    params = []
    for i in range(1, len(sizes)):
        params.append((random.normal(keys[i], (sizes[i-1], sizes[i])), 
                       random.normal(keys[i], (sizes[i],))))
    return params

def forward(params, x):
    for w, b in params[:-1]:
        x = jnp.dot(x, w) + b
        x = jax.nn.relu(x)
    w, b = params[-1]
    x = jnp.dot(x, w) + b
    return x

def softmax(logits):
    logsumexp_logits = logsumexp(logits, axis=1, keepdims=True)
    return jnp.exp(logits - logsumexp_logits)

def cross_entropy_loss(logits, labels):
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(jnp.sum(log_probs * labels, axis=1))

# Define the training parameters
input_size = X_train_array.shape[1]
hidden_sizes = [64, 32]  # Modify the hidden layer sizes as per your requirement
output_size = len(jnp.unique(y_train_array))
learning_rate = 0.001
num_epochs = 200
batch_size = len(X_train_array)  # Set batch size to the size of the training dataset
# Create the FFN model
rng_key = random.PRNGKey(0)
params = init_params(rng_key)

# Define the loss function
def loss_fn(params, x, y):
    logits = forward(params, x)
    probs = softmax(logits)
    labels = jax.nn.one_hot(y, output_size)
    return cross_entropy_loss(logits, labels)

# Create the optimizer
opt_init, opt_update, get_params = adam(learning_rate)
opt_state = opt_init(params)

# Define the update step
@jit
def update(params, x, y, opt_state):
    grads = grad(loss_fn)(params, x, y)
    return opt_update(0, grads, opt_state)

# Train the model
for epoch in range(num_epochs):
    perm = random.permutation(rng_key, len(X_train_array))
    for i in range(0, len(X_train_array), batch_size):
        batch_idx = perm[i:i+batch_size]
        X_batch = X_train_array[batch_idx]
        y_batch = y_train_array[batch_idx]
        params = get_params(opt_state)
        opt_state = update(params, X_batch, y_batch, opt_state)

    if (epoch + 1) % 10 == 0:
        params = get_params(opt_state)
        loss = loss_fn(params, X_train_array, y_train_array)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}")

# Evaluate the model on the test set
params = get_params(opt_state)
logits = forward(params, X_test_array)
predicted = jnp.argmax(logits, axis=1)

for j in range(len(predicted)):
    print(predicted[j], y_test_array[j])

accuracy = jnp.mean(predicted == y_test_array) * 100
print(f"Test Accuracy: {accuracy:.2f}%")

I dont understand why the jax code performs poorly. Could you please help me in underding the bug in the jax code.


Solution

  • There are 2 probles in your jax code that are, actually, in data processing:

    1. Your data are not scaled. If you look at your X_train_array definition, it is the jax version of X_train, that is the raw data. Please consider using:
    # Scale the features
    scaler = StandardScaler().fit(X_train)  # No need to flat it!
    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)
    
    # Convert the data to JAX arrays
    X_train_array = jnp.array(X_train, dtype=jnp.float32)
    y_train_array = jnp.array(y_train_reshaped, dtype=jnp.int32)
    X_test_array = jnp.array(X_test, dtype=jnp.float32)
    y_test_array = jnp.array(y_test_reshaped, dtype=jnp.int32)
    
    1. Your labels are of shape (N, 1) before one-hot encoding. After one-hot encoding it is (N, 1, n_out) while your predictions are of shape (N, n_out) so when you make your loss computation the two arrays are cast in (N, n_out, n_out) with repetitions and your loss is wrong. You can solve it very simply by remove the 1 in the reshape:
    # Reshape the target variables
    y_train_reshaped = jnp.reshape(y_train, (-1,))
    y_test_reshaped = jnp.reshape(y_test, (-1,))
    

    I tested your code with 300 epochs and lr=0.01 and I got an accuracy of 90% in test (and the loss decreased to 0.0001)