Search code examples
python-3.xpytorchloss-functionsoftmax

What is the Problem in my Building Softmax from Scratch in Pytorch


I read this post ans try to build softmax by myself. Here is the code

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
import numpy as np

#============================ get the dataset =========================

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

batch_size = 256
num_workers = 0  

train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)



#============================     train      =========================
num_inputs = 28 * 28
num_outputs = 10
epochs = 5
lr = 0.05

# Initi the Weight and bia
W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
W.requires_grad_(requires_grad = True)
b.requires_grad_(requires_grad=True)

# softmax function
def softmax(X):
    X = X.exp()
    den = X.sum(dim=1, keepdim=True)
    return X / den  

# loss
def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1))).sum()

# accuracy function
def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()
    

for epoch in range(epochs):

    train_loss_sum = 0.0
    train_acc_sum = 0.0
    n_train = 0

    for X, y in train_iter:
        # X.shape: [256, 1, 28, 28]
        # y.shape: [256]
        
        # flatten the X into [256, 28*28]
        X = X.flatten(start_dim=1)  
        y_pred = softmax(torch.mm(X, W) + b)
        
        loss = cross_entropy(y_pred, y)
       
        loss.backward()

        W.data = W.data - lr * W.grad
        b.data = b.data - lr* b.grad

        W.grad.zero_()
        b.grad.zero_()

        train_loss_sum += loss.item() 

        train_acc_sum += accuracy(y_pred, y)
        n_train += y.shape[0]
    
    # evaluate the Test
   
    test_acc, n_test = 0.0, 0
    with torch.no_grad():

        for X_test, y_test in test_iter:
            X_test = X_test.flatten(start_dim=1) 
            y_test_pred = softmax(torch.mm(X_test, W) + b)
            test_acc += accuracy(y_test_pred, y_test)
            n_test += y_test.shape[0]

    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_loss_sum/n_train , train_acc_sum / n_train, test_acc / n_test))

Compare with original post, Here I turn

def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))

into

def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1))).sum()

Since the backward need a scalar.

However, My results are

epoch 1, loss nan, train acc 0.000, test acc 0.000
epoch 2, loss nan, train acc 0.000, test acc 0.000
epoch 3, loss nan, train acc 0.000, test acc 0.000
epoch 4, loss nan, train acc 0.000, test acc 0.000
epoch 5, loss nan, train acc 0.000, test acc 0.000

Any idea?

Thanks.


Solution

  • Change:

    def cross_entropy(y_hat, y):
        return - torch.log(y_hat.gather(1, y.view(-1, 1))).sum()
    

    To:

    def cross_entropy(y_hat, y):
        return - torch.log(y_hat[range(len(y_hat)), y] + 1e-8).sum()
    

    Outputs should be something like:

    epoch 1, loss 9.2651, train acc 0.002, test acc 0.002
    epoch 2, loss 7.8493, train acc 0.002, test acc 0.002
    epoch 3, loss 6.6875, train acc 0.002, test acc 0.003
    epoch 4, loss 6.0928, train acc 0.003, test acc 0.003
    epoch 5, loss 5.1277, train acc 0.003, test acc 0.003
    

    And be aware the problem of nan can also cause by X = X.exp() in the softmax(X), when X is too big then exp() will outputs inf, when this happen you could try to clip the X before using exp()