
One-hot encode along arbitrary dimension with NumPy


Given a NumPy array with arbitrarily many dimensions, I would like to be able to one-hot encode any one of these dimensions. For example, say I have an array a of shape (10, 20, 30, 40). I might want to one-hot encode the second dimension, i.e. transform a such that the result contains only the values 0 and 1, and a[i, :, j, k] contains exactly one nonzero entry for every choice of i, j and k (at the position of the previous maximum value along that dimension).

I thought about first obtaining a.argmax(axis=1) and then using np.ogrid to turn that into indices pointing to the maxima, but I can't figure out the details. I'm also worried about memory consumption with this approach.

Is there an easy way to do this (ideally requiring little additional memory)?
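
For reference, the argmax-plus-np.ogrid idea mentioned above could look roughly like the sketch below. The helper name onehot_via_ogrid is hypothetical and not part of the original post. Because np.ogrid returns open (broadcastable) grids, the index arrays stay small; the main extra memory should be the argmax result and the boolean output.

```python
import numpy as np

def onehot_via_ogrid(a, axis):
    """Hypothetical helper: one-hot encode the argmax along `axis` via np.ogrid."""
    # Argmax positions, with the reduced axis restored as a length-1 dimension
    idx = np.expand_dims(a.argmax(axis=axis), axis)
    # One open index grid per dimension, e.g. shapes (n0,1,1), (1,n1,1), (1,1,n2)
    grids = list(np.ogrid[tuple(slice(n) for n in a.shape)])
    # Along `axis`, index with the argmax positions instead of the full range
    grids[axis] = idx
    h = np.zeros(a.shape, dtype=bool)
    # The grids broadcast against idx, so exactly one entry per 1-D slice is set
    h[tuple(grids)] = True
    return h

# Quick self-check: every 1-D slice along the encoded axis has exactly one True
rng = np.random.default_rng(0)
a = rng.integers(11, 99, size=(2, 3, 4, 5))
for ax in range(a.ndim):
    assert (onehot_via_ogrid(a, ax).sum(axis=ax) == 1).all()
```

On ties, argmax picks the first maximum along the axis, so the sketch still marks a single position per slice.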


Solution

  • Here's one way with array-assignment -

    import numpy as np

    def onehotencode_along_axis(a, axis):
        # Set up the output one-hot encoded bool array
        h = np.zeros(a.shape, dtype=bool)
        # Positions of the maxima along `axis` (first occurrence on ties)
        idx = a.argmax(axis=axis)
    
        # Restore the reduced axis so idx has as many dimensions as the input
        idx = np.expand_dims(idx, axis) # Thanks to @Peter
    
        # Finally assign True values at the argmax positions
        np.put_along_axis(h, idx, True, axis=axis)
        return h
    

    Sample runs on 2D case -

    In [109]: np.random.seed(0)
         ...: a = np.random.randint(11,99,(4,5))
    
    In [110]: a
    Out[110]: 
    array([[55, 58, 75, 78, 78],
           [20, 94, 32, 47, 98],
           [81, 23, 69, 76, 50],
           [98, 57, 92, 48, 36]])
    
    In [112]: onehotencode_along_axis(a, axis=0)
    Out[112]: 
    array([[False, False, False,  True, False],
           [False,  True, False, False,  True],
           [False, False, False, False, False],
           [ True, False,  True, False, False]])
    
    In [113]: onehotencode_along_axis(a, axis=1)
    Out[113]: 
    array([[False, False, False,  True, False],
           [False, False, False, False,  True],
           [ True, False, False, False, False],
           [ True, False, False, False, False]])
    

    Sample run for verification on higher (multidimensional) 5D case -

    In [114]: np.random.seed(0)
         ...: a = np.random.randint(11,99,(2,3,4,5,6))
         ...: for i in range(a.ndim):
         ...:     out = onehotencode_along_axis(a, axis=i)
         ...:     print(np.allclose(out.sum(axis=i), 1))
    True
    True
    True
    True
    True
    

    If you need the final output as an int array with 0s and 1s, use a view on the boolean output array, e.g. for axis=0 (and likewise for any other axis):

    onehotencode_along_axis(a, axis=0).view('i1')
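
To illustrate why a view is suggested here, the following minimal sketch compares .view('i1') with .astype on a small hand-built boolean array. The array and the variable names are made up for illustration. The view reinterprets the same bytes as int8 without copying, which only works because bool and int8 are both one byte wide; a wider integer type would need astype, which allocates a new array.

```python
import numpy as np

# Small stand-in for a one-hot encoded boolean result
h = np.zeros((2, 3), dtype=bool)
h[0, 1] = True
h[1, 2] = True

as_view = h.view('i1')         # same memory, reinterpreted as int8
as_copy = h.astype(np.int64)   # new array with a wider integer type

print(as_view.dtype)                  # int8
print(np.shares_memory(h, as_view))   # True
print(np.shares_memory(h, as_copy))   # False
```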