Tags: python, tensorflow, tensorflow2.0, tensorflow-datasets

How to construct an object of the same type that `tf.one_hot` returns?


I have a function input_preprocess that I am using in the data pipeline:

def input_preprocess(image, label):
    if label == 1:
        return tf.zeros(NUM_CLASSES)
    else:
        label = tf.one_hot(label, NUM_CLASSES)
    return image, label

The problem seems to be that whatever tf.one_hot returns is a sequence, while what tf.zeros returns is not.

I get the following error:

    The two structures don't have the same nested structure.

    First structure: type=Tensor str=Tensor("cond/zeros_like:0", shape=(28,), dtype=float32)

    Second structure: type=tuple str=(<tf.Tensor 'args_0:0' shape=(224, 224, 3) dtype=float32>, <tf.Tensor 'cond/one_hot:0' shape=(28,) dtype=float32>)

    More specifically: Substructure "type=tuple str=(<tf.Tensor 'args_0:0' shape=(224, 224, 3) dtype=float32>, <tf.Tensor 'cond/one_hot:0' shape=(28,) dtype=float32>)" is a sequence, while substructure "type=Tensor str=Tensor("cond/zeros_like:0", shape=(28,), dtype=float32)" is not
    Entire first structure:
    .
    Entire second structure:
    (., .)
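The structure check behind this message can be reproduced outside the pipeline with tf.nest.assert_same_structure, which produces the same kind of "don't have the same nested structure" error. A minimal sketch, assuming NUM_CLASSES = 28 (read off shape=(28,) above) and using zero tensors as stand-ins for real images:

```python
import tensorflow as tf

NUM_CLASSES = 28  # assumption: taken from shape=(28,) in the error message

# What the label == 1 branch returns: a single tensor (a "leaf" for tf.nest).
single = tf.zeros(NUM_CLASSES)
# What the other branch returns: an (image, label) tuple (a "sequence").
pair = (tf.zeros((224, 224, 3)), tf.one_hot(3, NUM_CLASSES))

mismatch = False
try:
    # A tensor and a 2-tuple are different nested structures, so this raises.
    tf.nest.assert_same_structure(single, pair)
except (ValueError, TypeError) as err:
    mismatch = True
    print(err)

# Two (image, label) tuples match, even though one label is all zeros.
tf.nest.assert_same_structure(pair, (tf.zeros((224, 224, 3)), single))
```

The type of the individual tensors never enters into it: only the nesting (tensor versus tuple) differs.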

How can I manually construct something that could stand in for what tf.one_hot returns?


Solution

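  • tf.one_hot returns the same kind of object as tf.zeros. In eager mode both return an EagerTensor (inside a Dataset.map function both are graph tf.Tensor objects instead, as type=Tensor in your error message shows):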

    import tensorflow as tf

    a = tf.one_hot(1, 2)  # one-hot vector of length 2
    print(type(a))
    b = tf.zeros(2)       # zero vector of length 2
    print(type(b))
    # tensorflow.python.framework.ops.EagerTensor
    # tensorflow.python.framework.ops.EagerTensor
    

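    I think your issue is that input_preprocess returns a single tensor when label == 1 (the first return) and an (image, label) tuple otherwise (the second return). Inside the data pipeline, AutoGraph turns the "if label == 1" test into a tf.cond, and both branches of a tf.cond must return the same nested structure. Returning image, tf.zeros(NUM_CLASSES) from the first branch makes the structures match; tf.zeros(NUM_CLASSES) already has the same type, shape and dtype (float32) as tf.one_hot(label, NUM_CLASSES), so no special object is needed.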
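A minimal sketch of the corrected function running inside a tf.data pipeline. NUM_CLASSES = 28 is an assumption read off the error message, and the tiny 8x8 images and the label values are hypothetical toy data standing in for the real dataset:

```python
import tensorflow as tf

NUM_CLASSES = 28  # assumption: taken from shape=(28,) in the error message

def input_preprocess(image, label):
    # Both branches now return an (image, label) tuple whose label is a
    # float32 tensor of shape (NUM_CLASSES,), so the tf.cond that AutoGraph
    # builds from this "if" sees matching structures.
    if label == 1:
        return image, tf.zeros(NUM_CLASSES)
    return image, tf.one_hot(label, NUM_CLASSES)

# Hypothetical toy data: 4 small images and their integer labels.
images = tf.zeros((4, 8, 8, 3))
labels = tf.constant([0, 1, 2, 1])

ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(input_preprocess)
print(ds.element_spec)

results = [lbl.numpy() for _, lbl in ds]
```

Because both branches agree, every element has a label spec of shape (NUM_CLASSES,) and dtype float32; label 1 maps to an all-zero vector and every other label to its one-hot vector.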