
How is log_softmax() implemented to compute its value (and gradient) with better speed and numerical stability?


Various frameworks and libraries (such as PyTorch and SciPy) provide a special implementation for computing log(softmax()) that is faster and numerically more stable. However, I cannot find the actual Python implementation of this function, log_softmax(), in either package.

Can anyone explain how this is implemented, or better, point me to the relevant source code?


Solution

    • The numerical error in the naive computation:
    >>> import numpy as np
    >>> x = np.array([1, -10, 1000])
    >>> np.exp(x) / np.exp(x).sum()
    RuntimeWarning: overflow encountered in exp
    RuntimeWarning: invalid value encountered in true_divide
    array([ 0.,  0., nan])
    

    There are two ways to avoid this numerical error when computing the softmax (or its logarithm):

    • Exp Normalization:

    $$\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - b}}{\sum_j e^{x_j - b}}$$

    def exp_normalize(x):
        b = x.max()
        y = np.exp(x - b)
        return y / y.sum()
    
    >>> exp_normalize(x)
    array([0., 0., 1.])
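Frameworks usually apply softmax along one dimension of a batch rather than to a single vector. Here is a minimal sketch of how the same shift could be applied per row; the name `exp_normalize_axis` and the `axis` parameter are my own for illustration and are not taken from PyTorch or SciPy.

```python
import numpy as np

def exp_normalize_axis(x, axis=-1):
    # Hypothetical helper: exp_normalize generalized to batched input.
    x = np.asarray(x, dtype=float)
    # Shift each slice by its own maximum; keepdims lets the result broadcast.
    b = x.max(axis=axis, keepdims=True)
    y = np.exp(x - b)
    return y / y.sum(axis=axis, keepdims=True)

batch = np.array([[1.0, -10.0, 1000.0],
                  [0.0, 0.0, 0.0]])
probs = exp_normalize_axis(batch)
```

Each row is normalized independently, so a huge value in one row does not affect the others.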
    
    • Log Sum Exp:

    $$\log \operatorname{softmax}(x)_i = x_i - c - \log \sum_j e^{x_j - c}$$

    def log_softmax(x):
        c = x.max()
        logsumexp = np.log(np.exp(x - c).sum())
        return x - c - logsumexp
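To see why a fused log_softmax is preferable to taking the log of a stable softmax, it may help to compare the two on the same input. This is only an illustrative sketch that repeats the two functions so it runs on its own; the variable names `two_step` and `fused` are mine.

```python
import numpy as np

def exp_normalize(x):
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()

def log_softmax(x):
    c = x.max()
    logsumexp = np.log(np.exp(x - c).sum())
    return x - c - logsumexp

x = np.array([1.0, -10.0, 1000.0])

# Two-step version: small probabilities underflow to 0, so their log is -inf.
with np.errstate(divide="ignore"):
    two_step = np.log(exp_normalize(x))

# Fused version: the log is applied analytically, so no information is lost.
fused = log_softmax(x)

print(two_step)
print(fused)
```

The two-step result loses the small entries entirely, while the fused version keeps finite log-probabilities for them.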
    
    

    Note that a reasonable choice for both b and c in the formulas above is max(x). With this choice, overflow in exp is impossible: the largest value exponentiated after shifting is 0.
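As for the gradient: since the derivative of log_softmax(x)_i with respect to x_j is (1 if i == j else 0) - softmax(x)_j, the backward pass can reuse the forward output and needs no extra shifting. Below is a minimal sketch of that idea, checked against finite differences. The function name `log_softmax_backward` is hypothetical, and this is not copied from the PyTorch or SciPy source.

```python
import numpy as np

def log_softmax(x):
    c = x.max()
    return x - c - np.log(np.exp(x - c).sum())

def log_softmax_backward(grad_output, log_probs):
    # Hypothetical helper: vector-Jacobian product for log_softmax.
    # exp(log_probs) recovers softmax(x) from the saved forward output.
    return grad_output - np.exp(log_probs) * grad_output.sum()

x = np.array([1.0, -2.0, 0.5])
g = np.array([0.3, -1.0, 2.0])   # arbitrary upstream gradient

out = log_softmax(x)
analytic = log_softmax_backward(g, out)

# Central finite differences as an independent check.
eps = 1e-6
numeric = np.zeros_like(x)
for j in range(x.size):
    step = np.zeros_like(x)
    step[j] = eps
    numeric[j] = g @ (log_softmax(x + step) - log_softmax(x - step)) / (2 * eps)
```

Because log_probs is at most 0, exp(log_probs) cannot overflow, so the backward pass is as well behaved as the forward one.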