Given a numpy array with arbitrarily many dimensions, I would like to be able to one-hot encode any of these dimensions. For example, say I have an array a of shape (10, 20, 30, 40). I might want to one-hot encode the second dimension, i.e. transform a such that the result only contains the values 0 and 1, and a[i, :, j, k] contains exactly one nonzero entry (a single 1) for every choice of i, j and k, located at the position of the previous maximum value along that dimension.
I thought about first obtaining a.argmax(axis=1) and then using np.ogrid to turn that into indices pointing to the maxima, but I can't figure out the details. I'm also worried about memory consumption with this approach.

Is there an easy way to do this (ideally requiring little additional memory)?
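For reference, the argmax + np.ogrid idea described above can be completed roughly as follows. This is a minimal sketch, not the method used in the answer; the helper name onehot_via_ogrid is hypothetical. np.ogrid builds open (sparse) index grids that broadcast against the argmax indices instead of being expanded to full-size index arrays, and the argmax result is slotted in at the position of the encoded axis.

```python
import numpy as np

def onehot_via_ogrid(a, axis):
    """Hypothetical helper: one-hot encode the argmax of `a` along `axis`."""
    axis = axis % a.ndim  # allow negative axes such as -1
    # Index of the maximum along `axis`; shape is a.shape without `axis`
    idx = a.argmax(axis=axis)
    # Open index grids for the remaining axes: grid k has length n along
    # its own dimension and length 1 elsewhere, so all grids broadcast
    # against idx
    grids = list(np.ogrid[tuple(slice(n) for n in idx.shape)])
    # The argmax indices fill the slot of the encoded axis
    grids.insert(axis, idx)
    h = np.zeros(a.shape, dtype=bool)  # 1 byte per element
    h[tuple(grids)] = True
    return h
```

The extra memory beyond the boolean output is the argmax array itself plus the small 1-D grids.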
Here's one way with array assignment:
import numpy as np

def onehotencode_along_axis(a, axis):
    # Set up the output one-hot encoded bool array
    h = np.zeros(a.shape, dtype=bool)
    # Positions of the maxima; `axis` is reduced away here
    idx = a.argmax(axis=axis)
    # Re-insert `axis` as a length-1 dimension so idx has the same ndim as h
    idx = np.expand_dims(idx, axis)  # Thanks to @Peter
    # Finally, assign True at the argmax positions
    np.put_along_axis(h, idx, True, axis=axis)
    return h
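To make the shapes concrete, here is the same sequence of steps applied to the (10, 20, 30, 40) example from the question with axis=1, using random integer data as an assumption. np.put_along_axis (available since NumPy 1.15) requires the index array to have the same number of dimensions as the target, which is why expand_dims puts back a length-1 axis after argmax removes it.

```python
import numpy as np

# Random integer data with the shape from the question (illustrative only)
a = np.random.default_rng(0).integers(0, 100, size=(10, 20, 30, 40))

idx = a.argmax(axis=1)             # shape (10, 30, 40): axis 1 is reduced away
idx = np.expand_dims(idx, axis=1)  # shape (10, 1, 30, 40): same ndim as a
h = np.zeros(a.shape, dtype=bool)  # boolean output, 1 byte per element
np.put_along_axis(h, idx, True, axis=1)

print(idx.shape)  # (10, 1, 30, 40)
print(h.nbytes)   # 240000
```

Each slice h[i, :, j, k] then holds exactly one True, at the index where a[i, :, j, k] is largest.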
Sample runs on the 2D case:
In [109]: np.random.seed(0)
     ...: a = np.random.randint(11,99,(4,5))

In [110]: a
Out[110]:
array([[55, 58, 75, 78, 78],
       [20, 94, 32, 47, 98],
       [81, 23, 69, 76, 50],
       [98, 57, 92, 48, 36]])

In [112]: onehotencode_along_axis(a, axis=0)
Out[112]:
array([[False, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False, False, False, False],
       [ True, False,  True, False, False]])

In [113]: onehotencode_along_axis(a, axis=1)
Out[113]:
array([[False, False, False,  True, False],
       [False, False, False, False,  True],
       [ True, False, False, False, False],
       [ True, False, False, False, False]])
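In the axis=1 output, row 0 has its only True at index 3 even though the maximum 78 also appears at index 4: argmax returns the first occurrence of the maximum, so ties are broken toward the lower index. The snippet below isolates that behaviour on the same row. The comparison against the maximum is a swapped-in alternative, not part of the answer; in n dimensions it reads a == a.max(axis=axis, keepdims=True), and it marks every tied position, so it is only a strict one-hot encoding when ties cannot occur.

```python
import numpy as np

# Row 0 of the 2D sample: the maximum 78 occurs at positions 3 and 4
row = np.array([55, 58, 75, 78, 78])

# argmax picks the FIRST index of the maximum, so only position 3 is set
first_only = np.zeros(row.shape, dtype=bool)
first_only[row.argmax()] = True

# Comparing against the maximum marks every tied position instead
all_ties = row == row.max()

print(first_only.astype(int))  # [0 0 0 1 0]
print(all_ties.astype(int))    # [0 0 0 1 1]
```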
Sample run for verification on a higher-dimensional (5D) case:
In [114]: np.random.seed(0)
     ...: a = np.random.randint(11,99,(2,3,4,5,6))
     ...: for i in range(a.ndim):
     ...:     out = onehotencode_along_axis(a, axis=i)
     ...:     print(np.allclose(out.sum(axis=i), 1))
True
True
True
True
True
If you need the final output as an int array of 0s and 1s, use a view on the boolean output array:

onehotencode_along_axis(a, axis=0).view('i1')

and so on.
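The view works because a NumPy bool occupies one byte, storing False as 0 and True as 1, which is the same itemsize as int8 ('i1'). A view reinterprets the existing buffer, so it costs no extra memory, whereas astype allocates a copy. The small array h below is an illustrative stand-in for the function's output.

```python
import numpy as np

# Small boolean one-hot array standing in for the function's output
h = np.zeros((2, 3), dtype=bool)
h[0, 1] = True
h[1, 2] = True

as_view = h.view('i1')       # same buffer reinterpreted as int8, no copy
as_copy = h.astype(np.int8)  # allocates a new int8 array

print(as_view.tolist())              # [[0, 1, 0], [0, 0, 1]]
print(np.shares_memory(h, as_view))  # True
print(np.shares_memory(h, as_copy))  # False
```

Because the view shares memory, writing into it also modifies the boolean array; use astype when an independent integer array is needed.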