I have this code that takes a tensor with a shape of (3, 3) and reshapes it to (9,). After that, it applies the one_hot function, but it throws an error.
This is the code:
import tensorflow as tf
t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)
print(tf.one_hot(tf.reshape(t1, -1), depth=2))
And the error is:
InvalidArgumentError: Value for attr 'TI' of float is not in the list of allowed values: uint8, int32, int64; NodeDef: {{node OneHot}}; Op<name=OneHot; signature=indices:TI, depth:int32, on_value:T, off_value:T -> output:T; attr=axis:int,default=-1; attr=T:type; attr=TI:type,default=DT_INT64,allowed=[DT_UINT8, DT_INT32, DT_INT64]> [Op:OneHot]
I'm working in a Google Colab notebook, so I think the problem might be the TensorFlow version or the data type of the tensor, but any other solutions would be appreciated.
You could simply cast your tensor to tf.int32 (or another integer type), since tf.one_hot expects integer indices; the error message lists the allowed index types as uint8, int32 and int64:
import tensorflow as tf
t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)
# Flatten to shape (9,), cast the float values to int32 indices, then one-hot encode
print(tf.one_hot(tf.cast(tf.reshape(t1, -1), dtype=tf.int32), depth=3))
tf.Tensor(
[[0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]], shape=(9, 3), dtype=float32)
Or, with depth=2:
tf.Tensor(
[[0. 1.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [0. 1.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [0. 1.]], shape=(9, 2), dtype=float32)
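To see why the indices have to be integers, here is a minimal NumPy sketch of what one-hot encoding computes. It is not TensorFlow's implementation; the helper name one_hot_sketch is hypothetical. It mirrors tf.one_hot's default behaviour for this case: integer indices only, on_value 1 and off_value 0, and an all-zero row for an index outside [0, depth), such as the -1 values in t2.

```python
import numpy as np


def one_hot_sketch(indices, depth):
    """Hypothetical NumPy sketch of one-hot encoding (not TensorFlow's code)."""
    indices = np.asarray(indices)
    # Like tf.one_hot, reject float indices instead of silently truncating them.
    if not np.issubdtype(indices.dtype, np.integer):
        raise TypeError(f"indices must be integers, got {indices.dtype}")
    # Compare every index with 0..depth-1; broadcasting adds a trailing axis.
    # An index outside [0, depth) matches nothing and yields an all-zero row.
    return (indices[..., None] == np.arange(depth)).astype(np.float32)


t1 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
# Same role as tf.reshape(t1, -1) followed by tf.cast(..., tf.int32)
flat = t1.reshape(-1).astype(np.int32)
print(one_hot_sketch(flat, depth=2))
```

The printed values match the depth=2 output above, and passing the uncast float array raises a TypeError, which is the NumPy analogue of the InvalidArgumentError in the question.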