Tags: python, tensorflow, keras, tensorflow2.0, tf.data.dataset

Passing tensors as an argument to a function


I am trying to normalize a tf.data.Dataset as shown below:

def normalization(image):
    print(image['label'])
    
    return 1
    

z = val.map(normalization) 

The val dataset looks like this:

<TakeDataset shapes: { id: (), image: (32, 32, 3), label: ()}, types: {id: tf.string, image: tf.uint8, label: tf.int64}>

If I print one element, I see:

  { 'id': <tf.Tensor: shape=(), dtype=string, numpy=b'train_31598'>, 'image': <tf.Tensor: shape=(32, 32, 3), dtype=uint8, 
 numpy=    array([[[151, 130, 106],
            .....,
            [104,  95,  77]]], dtype=uint8)>, 'label': <tf.Tensor: shape=(), dtype=int64, numpy=50>}

However, printing it inside my function outputs:

  {'id': <tf.Tensor 'args_1:0' shape=() dtype=string>, 'image': <tf.Tensor 'args_2:0' shape=(32, 32, 3) dtype=uint8>, 'label': <tf.Tensor 'args_3:0' shape=() dtype=int64>}

So I can't perform any transformation on my image array, because instead of a tensor holding the pixel values I get 'args_2:0'.
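A minimal sketch of what is likely going on, assuming TensorFlow 2.x: Dataset.map traces the mapped function into a graph, so Python's print runs while the function is being traced and shows symbolic placeholders such as 'args_2:0', whereas tf.print becomes part of the graph and prints each element's actual values during iteration. The symbolic tensors still accept TensorFlow ops such as tf.cast. The small dataset and the function name show_values below are hypothetical stand-ins, not the asker's real val pipeline.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for `val`: two elements with the same structure
# (id: string, image: uint8 of shape (32, 32, 3), label: int64).
val = tf.data.Dataset.from_tensor_slices({
    "id": np.array([b"train_0", b"train_1"]),
    "image": np.zeros((2, 32, 32, 3), dtype=np.uint8),
    "label": np.array([50, 7], dtype=np.int64),
})

def show_values(element):
    # Python print: runs while map traces this function, so it shows
    # a symbolic tensor (name, shape, dtype) without a value.
    print(element["label"])
    # tf.print: part of the traced graph, prints the real value of
    # every element when the dataset is iterated.
    tf.print("label:", element["label"])
    # Ordinary TensorFlow ops work on the symbolic tensors as well.
    return tf.cast(element["image"], tf.float32) / 255.0

# Iterating the mapped dataset runs the graph for each element.
images = list(val.map(show_values))
```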

How can I pass each element correctly to my normalization function?


Solution

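  • I tried your code on a standard dataset whose elements are (image, label) tuples, and it didn't work. With tuple elements, map unpacks each element into separate arguments, so the function has to accept both image and label, and image['label'] is not valid there because a tensor is indexed with integers or slices, not string keys. Here is my modification to your code: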

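    import tensorflow as tf

    # ds_train is assumed to yield (image, label) tuples,
    # e.g. a tfds dataset loaded with as_supervised=True
    def normalization(image, label):
        # print runs while map traces the function, so it shows a symbolic tensor
        print(image[0])
        # scale uint8 pixels in [0, 255] to float32 in [0, 1]
        return tf.cast(image, tf.float32) / 255., label

    z = ds_train.map(normalization)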
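For a dataset whose elements are dictionaries, like the val dataset in the question, map passes each element as a single dict, so looking up keys such as element["image"] is valid. A minimal sketch, assuming the structure shown in the printed element and using a small hypothetical stand-in for val: it rescales the image and keeps id and label unchanged, and also shows a variant that keeps only (image, label) pairs.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for `val` with the structure shown in the question.
val = tf.data.Dataset.from_tensor_slices({
    "id": np.array([b"train_31598", b"train_1"]),
    "image": np.full((2, 32, 32, 3), 255, dtype=np.uint8),
    "label": np.array([50, 3], dtype=np.int64),
})

def normalization(element):
    # Dict elements arrive as one argument, so key lookup works here.
    image = tf.cast(element["image"], tf.float32) / 255.0
    # Return the same structure with the image rescaled to [0, 1].
    return {"id": element["id"], "image": image, "label": element["label"]}

z = val.map(normalization)

# Variant: keep only (image, label) pairs, the shape Keras model.fit expects.
pairs = val.map(lambda e: (tf.cast(e["image"], tf.float32) / 255.0, e["label"]))

for element in z.take(1):
    print(element["id"].numpy(), element["label"].numpy(), element["image"].dtype)
```

Returning a dict keeps the id field available downstream; the tuple variant drops it, which is usually what a training loop wants.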