python numpy one-hot-encoding

np.argmax doesn't return an integer


I have some one-hot encoded data called testoutput, which has shape (1000, 14).

I want to decode it, so, following some advice I found online, I used the following code:

import numpy as np

# go from one-hot encoding to integer
def decode(datum):
    return np.argmax(datum)

predictedhits=np.empty((1000))
for i in range(testoutput.shape[0]):
    datum = testoutput[i]
    predictedhits[i] = decode(datum)
    print('Event',i,'predicted number of hits: %s' % predictedhits[i])

The problem is that I wanted and expected np.argmax to output an integer, but instead it outputs a numpy.float64. Can someone tell me why this happens and what to do about it? Simply doing predictedhits[i] = int(decode(datum)) doesn't change anything.

Thanks in advance!


Solution

  • You've misdiagnosed the problem. numpy.argmax is not returning an instance of numpy.float64; it returns an integer index. The float comes from predictedhits, which has dtype float64 because np.empty uses float64 by default. Storing any value into that array, even int(decode(datum)), converts it to a 64-bit float, and reading predictedhits[i] back produces a numpy.float64 object.

    Instead of iterating over the rows of testoutput one at a time and storing values into an empty array one by one, call argmax once along the axis you want (axis 1, across each row's 14 columns):

    predictedhits = np.argmax(testoutput, axis=1)
    

    This saves code, saves runtime, and produces an array with an integer dtype (np.intp) instead of float64.
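The diagnosis above can be checked with a small sketch. It builds a hypothetical stand-in for testoutput (random one-hot rows made with np.eye, since the real data isn't shown in the question), then compares the type np.argmax returns for one row, the dtype of an np.empty buffer, and the vectorized axis=1 call. The names labels, single, decoded and looped are illustrative only.

```python
import numpy as np

# Hypothetical stand-in for the asker's testoutput: 1000 one-hot rows over 14 classes.
rng = np.random.default_rng(0)
labels = rng.integers(0, 14, size=1000)
testoutput = np.eye(14)[labels]  # shape (1000, 14), dtype float64

# np.argmax on a single row returns a NumPy integer scalar, not a float.
single = np.argmax(testoutput[0])
print(type(single), isinstance(single, np.integer))

# np.empty defaults to float64, so anything stored in it is read back as float64.
predictedhits = np.empty(1000)
predictedhits[0] = int(single)  # int() does not help: the array converts it on assignment
print(predictedhits.dtype, type(predictedhits[0]))

# Vectorized decode: argmax along axis=1 yields an integer array directly.
decoded = np.argmax(testoutput, axis=1)
print(decoded.dtype, decoded.shape)
print(np.array_equal(decoded, labels))

# If a preallocated loop is really wanted, give the buffer an integer dtype instead.
looped = np.empty(1000, dtype=np.intp)
for i in range(testoutput.shape[0]):
    looped[i] = np.argmax(testoutput[i])
print(np.array_equal(looped, decoded))
```

The int() call in the loop is a no-op as far as the stored type goes, because the conversion to float happens when the value is written into the float64 array. Either decode with axis=1, or pass an integer dtype when preallocating.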