
How to Implement Softmax in Python When the Inputs Are Signed 8-Bit Integers


I am trying to implement a softmax function that takes in signed int8 input and returns a signed int8 output array.

The current implementation I have is this:

import numpy as np

def softmax_int8(inputs):
    inputs = np.array(inputs, dtype=np.int8)
    
    x = inputs.astype(np.int32)
    x_max = np.max(x)
    x_shifted = x - x_max
    scale_factor = 2 ** 14 
    exp_limit = 16
    exp_x = np.clip(x_shifted + exp_limit, 0, None)
    exp_x = (1 << exp_x)
    sum_exp_x = np.sum(exp_x)

    if sum_exp_x == 0:
        sum_exp_x = 1

    softmax_probs = (exp_x * scale_factor) // sum_exp_x
    max_prob = np.max(softmax_probs)
    min_prob = np.min(softmax_probs)
    range_prob = max_prob - min_prob if max_prob != min_prob else 1

    scaled_probs = ((softmax_probs - min_prob) * 255) // range_prob - 128
    outputs = scaled_probs.astype(np.int8)

    return outputs

I test it using this input: Input = [101, 49, 6, -34, -75, -79, -38, 120, -55, 115]

but I get this output: array([-128, -128, -128, -128, -128, -128, -128, 127, -128, -121], dtype=int8).

My expected output is array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=int8).

What am I doing wrong here and how can I fix it?


Solution

  • I think there are different mathematical definitions of softmax in different contexts.

    • Wikipedia definition (on real numbers): exp(z) / sum(exp(z))
    • What I inferred from your code: (1 << (z - z_max + 16)) / sum(1 << (z - z_max + 16)), or something similar. Note that 1 << n is the same as 2 ** n.

    The major difference is the base of the exponential. With a base that is too high you are very likely to get underflow and end up with a lot of -128s. There is also a bias that maps the result into the [-128, 127] range, but that part is trivial and less important.
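
    You can see the collapse directly with your own input; below is a minimal sketch (plain NumPy, reusing your constants) of what the base-2 step produces:

    import numpy as np

    x = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int32)
    shifted = x - x.max()                        # values range from -199 up to 0
    # Anything more than 16 below the max clips to exponent 0, so all of those
    # entries become 2**0 == 1 and are indistinguishable from one another.
    exp_x = 1 << np.clip(shifted + 16, 0, None)
    print(exp_x)                                 # [1 1 1 1 1 1 1 65536 1 2048]

    After the integer division and the final rescaling, all of those identical entries land on -128, which is exactly the output you are seeing.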

    It's highly likely that the library you took the test cases from uses a different definition than either of the above.

    I did some testing with your test case against the floating-point definition of softmax (plotted with matplotlib), and the following expression gives a good fit:

    softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
    

    You can imagine that you probably need to apply a >> 7 to the input bytes before taking the base-2 (1 <<) exponential. To get a completely identical result you would have to dig into that library's code, which I didn't have time to do.

    Below is the validation code:

    import numpy as np
    import matplotlib.pyplot as plt

    inarr = np.array([101, 49, 6, -34, -75, -79, -38, 120, -55, 115], dtype=np.int8).astype(np.double)
    expected_arr = np.array([-57, -70, -79, -86, -92, -94, -88, -54, -91, -56], dtype=np.int8).astype(np.double)
    print(expected_arr)

    # Base-e softmax on the rescaled input, mapped back into roughly int8 range
    softmax_naive = (np.exp(inarr / 128) / np.sum(np.exp(inarr / 128)) * 256) - 100
    print(softmax_naive - expected_arr)          # residuals of the fit
    plt.plot(inarr)
    plt.plot(expected_arr)
    plt.plot(softmax_naive)
    plt.show()

    [plot: validation of the softmax fit against the expected output]
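
    If you just want an int8-in / int8-out function that lands close to those expected values, you can wrap the fitted expression above. This is only a sketch built on the fitted formula (the name softmax_int8_fitted is mine), not the reference library's actual quantized arithmetic, and it comes within a few counts of the expected array rather than reproducing it exactly:

    import numpy as np

    def softmax_int8_fitted(inputs):
        # Hypothetical wrapper around the fitted expression above; an
        # approximation, not the reference library's exact integer math.
        x = np.asarray(inputs, dtype=np.int8).astype(np.float64)
        probs = np.exp(x / 128) / np.sum(np.exp(x / 128))
        out = np.round(probs * 256 - 100)
        return np.clip(out, -128, 127).astype(np.int8)

    print(softmax_int8_fitted([101, 49, 6, -34, -75, -79, -38, 120, -55, 115]))
    # -> [-57 -71 -79 -85 -89 -89 -85 -50 -87 -52]
    # close to, but not identical to, the expected
    #    [-57 -70 -79 -86 -92 -94 -88 -54 -91 -56]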