Tags: machine-learning, math, deep-learning, softmax

Is there any method to approximate the softmax probability under special conditions?


I'm trying to find an approach to compute the softmax probability without using exp().

Assume that:

Target: compute f(x1, x2, x3) = exp(x1) / [exp(x1) + exp(x2) + exp(x3)]

Conditions:

    1. -64 < x1, x2, x3 < 64

    2. The result only needs to be accurate to 3 decimal places.
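As a baseline for judging any approximation, here is a minimal sketch of the target f computed directly with std::exp in double precision, plus rounding to 3 decimal places as in condition 2. The names softmax3_ref and round3 are hypothetical, and the range check simply mirrors condition 1:

```cpp
#include <cassert>
#include <cmath>

// Reference value of f(x1, x2, x3) = exp(x1) / (exp(x1) + exp(x2) + exp(x3)).
// For -64 < x < 64, exp(x) stays below about 6.3e27, so the sum cannot overflow a double.
double softmax3_ref(double x1, double x2, double x3) {
    assert(x1 > -64 && x1 < 64 && x2 > -64 && x2 < 64 && x3 > -64 && x3 < 64);
    const double e1 = std::exp(x1);
    return e1 / (e1 + std::exp(x2) + std::exp(x3));
}

// Rounds to 3 decimal places (condition 2).
double round3(double v) {
    return std::round(v * 1000.0) / 1000.0;
}
```

For example, round3(softmax3_ref(1.0, 0.0, 0.0)) gives 0.576, i.e. e / (e + 2) rounded.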

Is there any way to find a polynomial that approximately represents the result under these conditions?


Solution

  • Since the range of activations is often vastly larger than the range over which exp(x) is representable, one usually needs to find the largest activation m = max(a, b, c) (here a, b, c play the role of x1, x2, x3) and subtract it from all the values.

    For the largest value a (selected or sorted so that a = m), this turns exp(a) / (exp(a) + exp(b) + exp(c)) into 1 / (1 + exp(b-m) + exp(c-m)).

    This saves one exp evaluation, although the sorting may well cost more than the fastest exp approximations.
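A minimal sketch of the max-subtraction step for three inputs, using std::exp and returning the probability of the first input (which need not be the largest). softmax3_shifted is a hypothetical name:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Probability of x1 under softmax(x1, x2, x3), computed after subtracting m = max(x1, x2, x3).
// Every shifted argument is <= 0, so each exp() lies in [0, 1], one of them is exactly 1,
// and the denominator therefore lies in [1, 3].
float softmax3_shifted(float x1, float x2, float x3) {
    const float m = std::max({x1, x2, x3});
    const float e1 = std::exp(x1 - m);
    return e1 / (e1 + std::exp(x2 - m) + std::exp(x3 - m));
}
```

Without the shift, exp(1000.0f) overflows to infinity and the division yields NaN; with it, softmax3_shifted(1000.0f, 0.0f, 0.0f) returns exactly 1.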


    For the exp function there is also a well-known first-order approximation, of the form (int)(x * 12102203.2f) + 127 * (1 << 23) - 486411 reinterpreted as a float; see "Fastest Implementation of the Natural Exponential Function Using SSE".
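A minimal scalar sketch of that first-order approximation, using std::memcpy for the reinterpretation (fast_exp and fast_exp_rel_error are hypothetical names). Here 12102203.2 is about 2^23 / ln(2), so x / ln(2) lands in the exponent field, and 127 * 2^23 is the single-precision exponent bias shifted into place; as I understand it, the 486411 correction shifts the result down to balance over- and underestimates. By my own estimate the relative error stays within roughly 4%, but it is worth checking against condition 2 before relying on it:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// First-order (Schraudolph-style) approximation of exp(x) in single precision.
// The integer part of x / ln(2) ends up in the exponent field; the mantissa bits
// hold a linear interpolation of 2^frac. Kept to -87 < x < 88 so the result is a
// normal, finite float and the int32 arithmetic cannot overflow.
inline float fast_exp(float x) {
    assert(x > -87.0f && x < 88.0f);
    const std::int32_t i =
        static_cast<std::int32_t>(x * 12102203.2f) + 127 * (1 << 23) - 486411;
    float result;
    std::memcpy(&result, &i, sizeof result);  // reinterpret the bits as a float
    return result;
}

// Relative error against std::exp, handy for scanning a range such as (-64, 64).
inline float fast_exp_rel_error(float x) {
    return fast_exp(x) / std::exp(x) - 1.0f;
}
```

For example, fast_exp(0.0f) comes out at about 0.971 rather than 1.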


    I recently found another method that is slightly less accurate but parallelises better on some SIMD implementations (Arm64), because it avoids float <-> int conversions:

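       #include <bit>      // std::bit_cast (C++20)
       #include <cstdint>

       // Approximates 2^x as 2^floor(x) * (1 + frac(x)) for IEEE-754 binary16
       // (e.g. std::float16_t, not bfloat16), float and double.
       template <typename T>
       T fastExp2(T x) {
            if constexpr (sizeof(T) == 2) {
                // 79 = 64 + 15: 0 10101 0 01111 xxxx // just 4 fractional bits
                x += (T)79.0f;
                // cast back to 16 bits: the shift promotes uint16_t to int
                return std::bit_cast<T>(static_cast<std::uint16_t>(std::bit_cast<std::uint16_t>(x) << 6));
            } else if constexpr (sizeof(T) == 4) {
                // 639 = 512 + 127: 0 10001000 001111111 xxxxx xxxxx xxxx // 14 fractional bits
                x += (T)639.0f;
                return std::bit_cast<T>(std::bit_cast<std::uint32_t>(x) << 9);
            } else {
                // in an else branch, so the 64-bit bit_cast is never instantiated for 16/32-bit T
                static_assert(sizeof(T) == 8, "only 16-, 32- and 64-bit IEEE floats are supported");
                // 5119 = 4096 + 1023: 0 10000001011 001111111111 xxxx... // 40 fractional bits
                x += (T)5119.0f;
                return std::bit_cast<T>(std::bit_cast<std::uint64_t>(x) << 12);
            }
       }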
    

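    In case it's not obvious what's happening here: the argument x is offset by a large, carefully selected constant (a power of two plus the exponent bias, e.g. 639 = 512 + 127 for float). Some or most of the fractional bits of x survive in the low mantissa bits, while the integer part of x is added to the exponent bias in the upper mantissa bits. At that point the approximate, reduced-precision result 2^floor(x) * (1 + frac(x)) is already embedded in the floating-point number and just needs to be shifted into position, which the final left shift by the width of the sign and exponent fields does.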
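For the float branch, the bookkeeping behind the 639 offset can be written out as below. This is my own derivation from the constants in the code, assuming the addition rounds to nearest, so treat it as a sketch rather than the author's notation:

```latex
\begin{aligned}
x + 639 &= 2^{9}\left(1 + \frac{127 + x}{2^{9}}\right), && -127 \le x < 385,\\
\text{mantissa field of } (x + 639) &= (127 + x)\cdot 2^{14}, && \text{rounded to 14 fractional bits},\\
\text{after } {\ll}\,9:\quad \text{exponent field} &= 127 + \lfloor x \rfloor, && \text{fraction} = x - \lfloor x \rfloor,\\
\text{result} &= 2^{\lfloor x \rfloor}\bigl(1 + x - \lfloor x \rfloor\bigr) \approx 2^{x}, && -126 \le x < 128.\\[4pt]
\text{Example } x = 1.5:\quad x + 639 &= 640.5, && \text{mantissa field } 128.5 \cdot 2^{14} = 2105344,\\
\text{result} &= 2^{128 - 127}\,(1 + 0.5) = 3.0, && \text{vs. } 2^{1.5} \approx 2.828.
\end{aligned}
```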

    Because fastExp2 computes 2^x rather than e^x, and exp(x) = 2^(x * log2(e)), one can premultiply the weights (and biases) of the last convolutional layer by log2(e) ≈ 1.44269504088896 to avoid the scaling inside the exponential function.
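Putting the pieces together, here is a sketch of the three-way softmax on logits that are assumed to be already scaled by log2(e), using a float-only copy of the fastExp2 trick plus the max subtraction described earlier. fast_exp2f and softmax3_base2 are hypothetical names, and std::memcpy stands in for std::bit_cast so the sketch also builds before C++20:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>

// Float-only version of the fastExp2 trick: returns about 2^floor(y) * (1 + frac(y)).
// Only valid for -126 <= y < 128; callers must clamp.
inline float fast_exp2f(float y) {
    assert(y >= -126.0f && y < 128.0f);
    y += 639.0f;  // 512 + 127: the integer part of y is added to the exponent bias
    std::uint32_t bits;
    std::memcpy(&bits, &y, sizeof bits);
    bits <<= 9;   // drop sign + exponent, move the embedded result into place
    float result;
    std::memcpy(&result, &bits, sizeof result);
    return result;
}

// Probability of the first class, given logits already scaled by log2(e)
// (e.g. by premultiplying the last layer's weights and biases).
float softmax3_base2(float y1, float y2, float y3) {
    const float m = std::max({y1, y2, y3});
    // After subtracting the max every argument is <= 0; clamp at -126 so tiny
    // terms stay inside the valid range (2^-126 is negligible anyway).
    const float e1 = fast_exp2f(std::max(y1 - m, -126.0f));
    const float e2 = fast_exp2f(std::max(y2 - m, -126.0f));
    const float e3 = fast_exp2f(std::max(y3 - m, -126.0f));
    return e1 / (e1 + e2 + e3);
}
```

The clamp matters: for inputs in (-64, 64) the scaled differences can reach about -185, outside the range where the bit trick yields a valid float. Because 1 + frac(y) overestimates 2^frac(y) by up to about 6%, my rough estimate is that the probabilities can be off by up to about 0.015, so this alone may not meet the 3-decimal condition.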