
Failing to implement logistic regression using the 'equinox' and 'optax' libraries


I am trying to implement logistic regression using the Equinox and Optax libraries on top of JAX. While training the model, the loss does not decrease over time and the model does not learn. Here is reproducible code with a toy dataset for reference:

import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx
import optax

data_key,model_key = jax.random.split(jax.random.PRNGKey(0),2)

### Generating toy-data

X_train = jax.random.normal(data_key, (1000,2))
y_train = X_train[:,0]+X_train[:,1]
y_train = jnp.where(y_train>0.5,1,0)

### Using equinox and optax
print("Training using equinox and optax")

epochs = 10000             
learning_rate = 0.1
n_inputs = X_train.shape[1]

class Logistic_Regression(eqx.Module):
    weight: jax.Array
    bias: jax.Array
    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))
        #self.weight = jnp.zeros((out_size, in_size))
        #self.bias = jnp.zeros((out_size,))
    def __call__(self, x):
        return jax.nn.sigmoid(self.weight @ x + self.bias)

@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x) 
    return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

@eqx.filter_jit
def make_step(model, x, y, opt_state):
    loss, grads = loss_fn(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

in_size, out_size = n_inputs, 1
model = Logistic_Regression(in_size, out_size, key=model_key)
optim = optax.sgd(learning_rate)
opt_state = optim.init(model)
for epoch in range(epochs):
    loss, model, opt_state = make_step(model,X_train,y_train, opt_state)
    loss = loss.item()
    if (epoch+1)%1000 ==0:
        print(f"loss at epoch {epoch+1}:{loss}")

# The following code implements logistic regression with scikit-learn and PyTorch; both work well and are included only for reference


### Using scikit-learn
print("Training using scikit-learn")
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
model = LogisticRegression()
model.fit(X_train,y_train)
y_pred = model.predict(X_train)
print("Train accuracy:",accuracy_score(y_train,y_pred))

## Using pytorch
print("Training using pytorch")
import numpy as np
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.nn import Sequential

X_train = np.array(X_train)
y_train = np.array(y_train)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:",device)
torch_LR= Sequential(nn.Linear(n_inputs, 1),
                nn.Sigmoid())
torch_LR.to(device)
criterion = nn.BCELoss() # define the loss (binary cross-entropy)
optimizer = SGD(torch_LR.parameters(), lr=learning_rate)

train_loss = []
for epoch in range(epochs):
    inputs, targets = torch.tensor(X_train).to(device), torch.tensor(y_train).to(device) # move the data to GPU if available
    optimizer.zero_grad() # clear the gradients
    yhat = torch_LR(inputs.float()) # compute the model output
    loss = criterion(yhat, targets.unsqueeze(1).float()) # calculate loss
    #train_loss_batch.append(loss.cpu().detach().numpy()) # store the loss
    loss.backward() # compute gradients
    optimizer.step() # update model weights
    if (epoch+1)%1000 ==0:
        print(f"loss at epoch {epoch+1}:{loss.cpu().detach().numpy()}")


I tried the SGD and Adam optimizers with different learning rates, but the result is the same. I also tried both zero and random weight initialisation. For the same data, I tried PyTorch and the LogisticRegression class from scikit-learn (I understand that scikit-learn does not use SGD; I include it only as a performance reference). The scikit-learn and PyTorch versions are included in the code block above. I have tried this with multiple classification datasets and still face the same problem.


Solution

  • The first time you print the loss is after 1000 epochs. If you change the loop to also print the loss for the first 10 epochs, you can see that the optimizer converges rapidly:

        # ...
        if epoch < 10 or (epoch + 1)%1000 ==0:
            print(f"loss at epoch {epoch+1}:{loss}")
    

    Here is the result:

    Training using equinox and optax
    loss at epoch 1:1.237254023551941
    loss at epoch 2:1.216030478477478
    loss at epoch 3:1.1952687501907349
    loss at epoch 4:1.174972414970398
    loss at epoch 5:1.1551438570022583
    loss at epoch 6:1.1357849836349487
    loss at epoch 7:1.1168975830078125
    loss at epoch 8:1.098482370376587
    loss at epoch 9:1.0805412530899048
    loss at epoch 10:1.0630732774734497
    loss at epoch 1000:0.6320337057113647
    loss at epoch 2000:0.6320337057113647
    loss at epoch 3000:0.6320337057113647
    

    By epoch 1000, the loss has converged to a minimum value from which it does not move.

    Given this, it looks like your optimizer is functioning correctly.


    Edit: After some debugging, I found that jax.vmap(model)(X_train) returns an array of shape (1000, 1), while y_train has shape (1000,). In the loss, y * jnp.log(pred_y) and (1 - y) * jnp.log(1 - pred_y) therefore broadcast to shape (1000, 1000): instead of pairing each prediction with its own label, the loss combines every prediction with every label. A log-loss averaged over all of these pairs is not a standard logistic regression objective.