Tags: python, tensorflow, machine-learning, math, google-colaboratory

TensorFlow: "Value for attr 'TI' of float is not in the list of allowed values" when one-hot encoding


I have code that takes a tensor with shape (3, 3) and reshapes it to (9,). It then applies tf.one_hot to the result, but that call throws an error.

This is the code:

import tensorflow as tf

t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)

print(tf.one_hot(tf.reshape(t1, -1), depth=2))

And the error is:

InvalidArgumentError: Value for attr 'TI' of float is not in the list of allowed values: uint8, int32, int64
    ; NodeDef: {{node OneHot}}; Op<name=OneHot; signature=indices:TI, depth:int32, on_value:T, off_value:T -> output:T; attr=axis:int,default=-1; attr=T:type; attr=TI:type,default=DT_INT64,allowed=[DT_UINT8, DT_INT32, DT_INT64]> [Op:OneHot]

I'm working in a Google Colab notebook, so I suspect the problem is either the TensorFlow version or the data type of the tensor, but any other solutions would be appreciated.
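The error message itself narrows this down: the OneHot op signature lists the indices type `TI` as allowed=[DT_UINT8, DT_INT32, DT_INT64], while the reshaped tensor is float32. A quick check along these lines (a sketch; `ALLOWED_INDEX_DTYPES` is just a local name copied from that list, not a TensorFlow constant) points at the dtype rather than the Colab TensorFlow version:

```python
import tensorflow as tf

t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
flat = tf.reshape(t1, [-1])  # shape (9,)

# Index dtypes accepted by the OneHot op, copied from the error message (attr TI).
# This tuple is a hypothetical local helper, not part of the TensorFlow API.
ALLOWED_INDEX_DTYPES = (tf.uint8, tf.int32, tf.int64)

print(tf.__version__)                      # informational only
print(flat.dtype)                          # float32
print(flat.dtype in ALLOWED_INDEX_DTYPES)  # False: float32 is not a valid index type
```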


Solution

  • Cast your tensor to tf.int32 (or another integer type) before calling tf.one_hot. As the error message lists, the OneHot op only accepts uint8, int32 or int64 indices, and your tensor is float32:

    import tensorflow as tf
    
    t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)
    t2 = tf.constant([[1], [-1], [1]], dtype=tf.float32)
    
    # tf.one_hot only accepts uint8, int32 or int64 indices, so cast the floats first.
    print(tf.one_hot(tf.cast(tf.reshape(t1, -1), dtype=tf.int32), depth=3))
    
    # Output:
    tf.Tensor(
    [[0. 1. 0.]
     [1. 0. 0.]
     [1. 0. 0.]
     [1. 0. 0.]
     [0. 1. 0.]
     [1. 0. 0.]
     [1. 0. 0.]
     [1. 0. 0.]
     [0. 1. 0.]], shape=(9, 3), dtype=float32)
    

    Or with depth=2, which is enough here because the indices are only 0 and 1:

    tf.Tensor(
    [[0. 1.]
     [1. 0.]
     [1. 0.]
     [1. 0.]
     [0. 1.]
     [1. 0.]
     [1. 0.]
     [1. 0.]
     [0. 1.]], shape=(9, 2), dtype=float32)
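One caveat with casting: tf.cast truncates toward zero when converting floats to integers, so a value that is only approximately integral (say 0.99 coming out of an earlier computation) silently becomes 0. The sketch below assumes the float values are meant to be class indices; it rounds before casting and derives the depth from the data. Deriving the depth this way is an assumption that the largest class actually occurs in the tensor; pass an explicit depth otherwise.

```python
import tensorflow as tf

t1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)

# Round before casting: tf.cast truncates toward zero, so near-integer floats
# could map to the wrong index. For exact 0.0/1.0 values this changes nothing.
indices = tf.cast(tf.round(tf.reshape(t1, [-1])), dtype=tf.int32)

# The indices here are only 0 and 1, so this yields depth=2.
# Assumes the largest class index actually appears in the data.
depth = int(tf.reduce_max(indices)) + 1

encoded = tf.one_hot(indices, depth=depth)
print(encoded.shape)  # (9, 2)

# Truncation vs. rounding on near-integer floats:
noisy = tf.constant([0.99, 0.0, 1.98, 1.0])
print(tf.cast(noisy, tf.int32).numpy())            # [0 0 1 1]
print(tf.cast(tf.round(noisy), tf.int32).numpy())  # [1 0 2 1]
```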