
Adapting the PyTorch softmax function


I am currently looking into the softmax function and would like to adapt the original implementation for some small tests.

I have been through the docs, but there wasn't much useful information about the function. This is the PyTorch Python implementation:

def __init__(self, dim=None):
    super(Softmax, self).__init__()
    self.dim = dim

def __setstate__(self, state):
    self.__dict__.update(state)
    if not hasattr(self, 'dim'):
        self.dim = None

def forward(self, input):
    return F.softmax(input, self.dim, _stacklevel=5)

Where can I find the F.softmax implementation?

One of the things I want to try, for instance, is the soft-margin softmax described here: Soft-Margin Softmax for Deep Classification

Where would be the best place to start? Thanks in advance!


Solution

  • Softmax Implementation in PyTorch and Numpy

    A Softmax function is defined as follows:

    $$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

    A direct implementation of the above formula is as follows:

    def softmax(x):
        return np.exp(x) / np.exp(x).sum(axis=0)
    

    The above implementation can run into arithmetic overflow, because np.exp(x) grows very quickly for large values of x.
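
To make the overflow concrete, here is a minimal sketch that feeds small and large inputs to the direct formula. The name `naive_softmax` and the sample values are chosen for illustration only; the warnings are silenced with `np.errstate` so the output stays readable.

```python
import numpy as np

def naive_softmax(x):
    # Direct translation of the formula, with no stabilisation.
    e = np.exp(x)
    return e / e.sum(axis=0)

small = np.array([1.0, 2.0, 3.0])
large = np.array([1000.0, 1001.0, 1002.0])

with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(small))  # finite probabilities that sum to 1
    print(naive_softmax(large))  # exp overflows to inf, and inf / inf gives nan
```

In float64, `np.exp` overflows once its argument passes roughly 709, so logits of this size are enough to turn every output into nan.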

    To avoid the overflow, we can multiply the numerator and denominator of the softmax equation by a constant C, which leaves the result unchanged. The softmax function then becomes:

    $$\mathrm{softmax}(x)_i = \frac{C\,e^{x_i}}{C\sum_j e^{x_j}} = \frac{e^{x_i + \log C}}{\sum_j e^{x_j + \log C}}$$

    PyTorch uses this approach, taking log(C) = -max(x), so the largest exponent is always 0. Below is an equivalent implementation written with PyTorch operations:

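    import torch

    def softmax_torch(x):  # Assumes x has at least 2 dimensions; softmax is taken over dim 1
        # torch.max with a dim returns (values, indices); [0] selects the values
        maxes = torch.max(x, 1, keepdim=True)[0]
        x_exp = torch.exp(x - maxes)
        x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
        probs = x_exp / x_exp_sum
        return probs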
    

    An equivalent NumPy implementation is as follows:

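    import numpy as np

    def softmax_np(x):  # Softmax over axis 1, matching softmax_torch
        # np.max returns the values directly, so no [0] indexing is needed here
        maxes = np.max(x, axis=1, keepdims=True)
        x_exp = np.exp(x - maxes)
        x_exp_sum = np.sum(x_exp, axis=1, keepdims=True)
        probs = x_exp / x_exp_sum
        return probs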
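
As a quick check of the max-subtraction trick, the sketch below uses a small axis-general NumPy variant. The name `stable_softmax` and its `axis` parameter are hypothetical, not part of PyTorch or NumPy. The two rows differ only by a constant offset of 999, so they should map to the same probabilities, and the large row should stay finite.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # Shift by the maximum along the chosen axis so the largest exponent is 0.
    x = np.asarray(x, dtype=np.float64)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

logits = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1001.0, 1002.0]])
probs = stable_softmax(logits, axis=1)
print(probs)              # both rows contain the same probabilities
print(probs.sum(axis=1))  # each row sums to 1
```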
    

    We can compare both results against PyTorch's own torch.nn.functional.softmax using the snippet below:

    import torch
    import numpy as np
    if __name__ == "__main__":
        x = torch.randn(1, 3, 5, 10)
        std_pytorch_softmax = torch.nn.functional.softmax(x, dim=1)  # pass dim explicitly; the implicit choice is deprecated
        pytorch_impl = softmax_torch(x)
        numpy_impl = softmax_np(x.detach().cpu().numpy())
        print("Shapes: x --> {}, std --> {}, pytorch impl --> {}, numpy impl --> {}".format(x.shape, std_pytorch_softmax.shape, pytorch_impl.shape, numpy_impl.shape))
        print("Std and torch implementation are same?", torch.allclose(std_pytorch_softmax, pytorch_impl))
        print("Std and numpy implementation are same?", torch.allclose(std_pytorch_softmax, torch.from_numpy(numpy_impl)))
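
For the soft-margin softmax mentioned in the question, one possible starting point is to adjust the logits before calling the standard loss rather than rewriting softmax itself. The sketch below assumes the formulation in which a margin m is subtracted from the target-class logit only; please check this against the paper before relying on it. The function name `soft_margin_softmax_loss` and the default margin are hypothetical.

```python
import torch
import torch.nn.functional as F

def soft_margin_softmax_loss(logits, target, margin=0.5):
    # logits: (N, C) raw scores, target: (N,) integer class indices.
    # Assumed formulation: subtract the margin from the target-class logit only.
    one_hot = F.one_hot(target, num_classes=logits.size(1)).to(logits.dtype)
    adjusted = logits - margin * one_hot
    # cross_entropy applies a numerically stable log-softmax internally.
    return F.cross_entropy(adjusted, target)

torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
target = torch.tensor([0, 2, 1, 1])

plain = F.cross_entropy(logits, target)
sm = soft_margin_softmax_loss(logits, target, margin=0.5)
sm.backward()  # gradients flow through the adjusted logits as usual
print(plain.item(), sm.item())
```

With `margin=0` this reduces to the ordinary cross-entropy loss, which makes it easy to verify the adaptation against the built-in function.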
    