I have PyTorch code that reaches more than 95% accuracy. The code implements a feedforward neural network to classify the scikit-learn digits dataset: it trains the model with the Adam optimizer and cross-entropy loss, then evaluates it on the test set by computing the accuracy.
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# Load the digits dataset
digits = load_digits()
# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
digits.data, digits.target, test_size=0.2, random_state=42
)
# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Convert the data to PyTorch tensors
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)
# Define the FFN model
class FFN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(FFN, self).__init__()
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_sizes)):
            if i == 0:
                self.hidden_layers.append(nn.Linear(input_size, hidden_sizes[i]))
            else:
                self.hidden_layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            self.hidden_layers.append(nn.ReLU())
        self.output_layer = nn.Linear(hidden_sizes[-1], output_size)
    def forward(self, x):
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x
# Define the training parameters
input_size = X_train.shape[1]
hidden_sizes = [64, 32] # Modify the hidden layer sizes as per your requirement
output_size = len(torch.unique(y_train_tensor))
learning_rate = 0.001
num_epochs = 200
batch_size = len(X_train) # Set batch size to the size of the training dataset
# Create the FFN model
model = FFN(input_size, hidden_sizes, output_size)
# Define the loss function
criterion = nn.CrossEntropyLoss()
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Train the model
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)
    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
# Evaluate the model on the test set
with torch.no_grad():
    model.eval()
    outputs = model(X_test_tensor)
    _, predicted = torch.max(outputs.data, 1)
    for j in range(len(predicted)):
        print(predicted[j], y_test_tensor[j])
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0) * 100
print(f"Test Accuracy: {accuracy:.2f}%")
I also have what should be the equivalent JAX code, which performs with less than 10% accuracy:
import jax
import jax.numpy as jnp
from jax import grad, jit, random, value_and_grad
from jax.scipy.special import logsumexp
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from jax.example_libraries.optimizers import adam, momentum, sgd, nesterov, adagrad, rmsprop
from jax import nn as jnn
# Load the digits dataset
digits = load_digits()
# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)
# Reshape the target variables
y_train_reshaped = jnp.reshape(y_train, (-1, 1))
y_test_reshaped = jnp.reshape(y_test, (-1, 1))
X_train_reshaped = jnp.reshape(X_train, (-1, 1))
X_test_reshaped = jnp.reshape(X_test, (-1, 1))
#print(np.shape(X_train),np.shape(y_train_reshaped),np.shape(y_train))
# Scale the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_reshaped)
y_test_scaled = scaler.transform(y_test_reshaped)
# Convert the data to JAX arrays
X_train_array = jnp.array(X_train, dtype=jnp.float32)
y_train_array = jnp.array(y_train_reshaped, dtype=jnp.int32)
X_test_array = jnp.array(X_test, dtype=jnp.float32)
y_test_array = jnp.array(y_test_reshaped, dtype=jnp.int32)
# Define the FFN model
def init_params(rng_key):
    sizes = [X_train_array.shape[1]] + hidden_sizes + [output_size]
    keys = random.split(rng_key, len(sizes))
    params = []
    for i in range(1, len(sizes)):
        params.append((random.normal(keys[i], (sizes[i-1], sizes[i])),
                       random.normal(keys[i], (sizes[i],))))
    return params
def forward(params, x):
    for w, b in params[:-1]:
        x = jnp.dot(x, w) + b
        x = jax.nn.relu(x)
    w, b = params[-1]
    x = jnp.dot(x, w) + b
    return x
def softmax(logits):
    logsumexp_logits = logsumexp(logits, axis=1, keepdims=True)
    return jnp.exp(logits - logsumexp_logits)
def cross_entropy_loss(logits, labels):
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(jnp.sum(log_probs * labels, axis=1))
# Define the training parameters
input_size = X_train_array.shape[1]
hidden_sizes = [64, 32] # Modify the hidden layer sizes as per your requirement
output_size = len(jnp.unique(y_train_array))
learning_rate = 0.001
num_epochs = 200
batch_size = len(X_train_array) # Set batch size to the size of the training dataset
# Create the FFN model
rng_key = random.PRNGKey(0)
params = init_params(rng_key)
# Define the loss function
def loss_fn(params, x, y):
    logits = forward(params, x)
    probs = softmax(logits)
    labels = jax.nn.one_hot(y, output_size)
    return cross_entropy_loss(logits, labels)
# Create the optimizer
opt_init, opt_update, get_params = adam(learning_rate)
opt_state = opt_init(params)
# Define the update step
@jit
def update(params, x, y, opt_state):
    grads = grad(loss_fn)(params, x, y)
    return opt_update(0, grads, opt_state)
# Train the model
for epoch in range(num_epochs):
    perm = random.permutation(rng_key, len(X_train_array))
    for i in range(0, len(X_train_array), batch_size):
        batch_idx = perm[i:i+batch_size]
        X_batch = X_train_array[batch_idx]
        y_batch = y_train_array[batch_idx]
        params = get_params(opt_state)
        opt_state = update(params, X_batch, y_batch, opt_state)
    if (epoch + 1) % 10 == 0:
        params = get_params(opt_state)
        loss = loss_fn(params, X_train_array, y_train_array)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}")
# Evaluate the model on the test set
params = get_params(opt_state)
logits = forward(params, X_test_array)
predicted = jnp.argmax(logits, axis=1)
for j in range(len(predicted)):
    print(predicted[j], y_test_array[j])
accuracy = jnp.mean(predicted == y_test_array) * 100
print(f"Test Accuracy: {accuracy:.2f}%")
I don't understand why the JAX code performs so poorly. Could you please help me understand the bug in the JAX code?
There are 2 problems in your JAX code, and both of them are actually in the data processing:
1. X_train_array is defined as the JAX version of X_train, which is the raw, unscaled data. You fit the scaler on a (-1, 1)-flattened copy of the features (and then even apply it to y_test_reshaped), but the scaled features are never used for training (see the quick check right below).
2. The targets are reshaped to (-1, 1) instead of staying a flat vector, so jax.nn.one_hot produces labels of shape (N, 1, num_classes) that broadcast against the (N, num_classes) log-probabilities into a (N, N, num_classes) array, and the loss no longer pairs each example with its own label.
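To see what the flattened fit actually does, here is a small standalone check, separate from your script (the names X, per_feature and flattened are only for illustration): fitting StandardScaler on the (-1, 1)-reshaped array gives a single global mean/scale instead of one per pixel.

from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler

X = load_digits().data                                # shape (1797, 64), one row per image
per_feature = StandardScaler().fit(X)                 # 64 means and 64 scales, one per pixel
flattened = StandardScaler().fit(X.reshape(-1, 1))    # a single global mean and scale

print(per_feature.mean_.shape)  # (64,)
print(flattened.mean_.shape)    # (1,)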
Please consider using:

# Reshape the target variables (keep them flat)
y_train_reshaped = jnp.reshape(y_train, (-1,))
y_test_reshaped = jnp.reshape(y_test, (-1,))
# Scale the features
scaler = StandardScaler().fit(X_train)  # No need to flatten it!
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
# Convert the data to JAX arrays
X_train_array = jnp.array(X_train, dtype=jnp.float32)
y_train_array = jnp.array(y_train_reshaped, dtype=jnp.int32)
X_test_array = jnp.array(X_test, dtype=jnp.float32)
y_test_array = jnp.array(y_test_reshaped, dtype=jnp.int32)
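And to see why the flat (-1,) target shape matters, here is a small standalone check (dummy zero logits for 4 examples and 10 classes, names chosen only for illustration) of how the label shape changes what your one-hot encoding and loss multiply together:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

num_classes = 10
logits = jnp.zeros((4, num_classes))                        # dummy logits for 4 examples
log_probs = logits - logsumexp(logits, axis=1, keepdims=True)

y_flat = jnp.array([3, 1, 4, 1], dtype=jnp.int32)           # shape (4,)
y_col = jnp.reshape(y_flat, (-1, 1))                        # shape (4, 1)

labels_flat = jax.nn.one_hot(y_flat, num_classes)           # (4, 10): one row per example
labels_col = jax.nn.one_hot(y_col, num_classes)             # (4, 1, 10): extra axis

print((log_probs * labels_flat).shape)  # (4, 10)    -> each example paired with its own label
print((log_probs * labels_col).shape)   # (4, 4, 10) -> every example mixed with every label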
I tested your code with these fixes, 300 epochs, and lr=0.01, and I got an accuracy of 90% on the test set (the loss decreased to 0.0001).