Tags: python, tensorflow, tensorflow-datasets, image-rotation, data-augmentation

"args_0:0" when printing Tensor


I have an augmentation function that is mapped over a dataset built from a generator; however, for some reason, the tfa.image.rotate call inside it is causing an error.

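import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
from skimage.transform import resize  # assumption: the question does not say which resize is used

AUTOTUNE = tf.data.AUTOTUNE

def customGenerator(input_file_paths, dims, data_type):
    # from_generator passes string args as bytes, hence the .decode calls
    for i, file_path in enumerate(input_file_paths):
        # was ["png" or "tif"], which evaluates to ["png"] and never matches "tif"
        if data_type.decode("utf-8") in ["png", "tif"]:
            img = plt.imread(file_path.decode("utf-8"))
        elif data_type.decode("utf-8") == "npy":
            img = np.load(file_path.decode("utf-8"))
        x = resize(img[:,:,:3], dims)
        yield x, x

def augment(image,label) :
    print('image', image)
    print('shape', image.shape)
    print('type', type(image))

    #angle = random.uniform(0, tf.constant(np.pi))
    image = tfa.image.rotate(image, tf.constant(np.pi))
    # note: nothing is returned, but Dataset.map needs the new (image, label) element


train_dataset = tf.data.Dataset.from_generator(generator=customGenerator, 
                                                 output_types=(np.float32, np.float32), 
                                                 output_shapes=(dims, dims), 
                                                 args=[X_train_paths, dims, "png"])

train_dataset = train_dataset.map(augment, num_parallel_calls=AUTOTUNE)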

I looked at how other people used tfa.image.rotate, and their code worked fine. I tried printing the image variable in the augment function, which gave this:

print('image', image) # these lines are in the augment function; output below
print('type', type(image))
image Tensor("args_0:0", shape=(256, 256, 3), dtype=float32)
type <class 'tensorflow.python.framework.ops.Tensor'>

In contrast, in other users' implementations, where the image is not mapped over a dataset, print(image) and print(type(image)) print this:

image tf.Tensor(
[[[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  ...
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 ...

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  ...
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]], shape=(256, 256, 3), dtype=float32)

type <class 'tensorflow.python.framework.ops.EagerTensor'>

I expected to see this output when printing image inside augment, so I am unsure what is going on. Also, tf.executing_eagerly() returns True. So I have a few questions:

What exactly does "args_0:0" mean?

Should the image in the augment function be of type <class 'tensorflow.python.framework.ops.EagerTensor'> instead of a normal tensor?

Is there some way to turn "args_0:0" into the format I expected, where it prints the array of numbers? I believe that would fix the rotation function.

Finally, if not, is there a better way to augment the image with a random rotation?

Thank you for your time and help.


Solution

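  • args_0:0 is the name of a symbolic (graph) tensor, not an eager one. Dataset.map traces augment into a graph, and args_0 is the placeholder for its first positional argument (the image); the :0 suffix means output 0 of that placeholder op. Such a tensor has a shape and dtype but no values yet, so print() cannot show the numbers.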
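A minimal sketch of the difference, assuming TensorFlow 2.x and a tiny in-memory dataset (the inspect function and the arrays are made up for illustration). Python's print() inside the mapped function fires only while tf.data traces it, and tf.executing_eagerly() is False there even though it is True outside; tf.print() becomes part of the graph and prints the actual values for every element.

```python
import numpy as np
import tensorflow as tf

print("eager outside map:", tf.executing_eagerly())

eager_inside_map = []  # records tf.executing_eagerly() each time the body runs

def inspect(image, label):
    # This Python body runs while tf.data traces the function into a graph,
    # so `image` is a symbolic tensor and print() shows only name, shape, dtype.
    eager_inside_map.append(tf.executing_eagerly())
    print("print():", image)
    # tf.print is added to the graph, so it runs for every element and shows values.
    tf.print("tf.print():", image)
    return image, label

images = np.ones((3, 2, 2), dtype=np.float32)
labels = np.zeros(3, dtype=np.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).map(inspect)

# Pulling an element out in eager code gives EagerTensors that hold values.
first_image, first_label = next(iter(dataset))
print(type(first_image).__name__, first_image.numpy().tolist())
```

Iterating over the dataset in ordinary eager code is therefore the place to look at values; inside the mapped function, use tf.print.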

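    I made some changes to your code in order to make it work: the generator now reads the PNG files with TensorFlow ops, and augment returns the rotated image together with the label, which Dataset.map requires.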

    Code:

    import tensorflow_addons as tfa
    import os
    import matplotlib.pyplot as plt
    import tensorflow as tf
    import numpy as np
    
    
    def customGenerator(input_file_paths, dims, data_type):
        # data_type is unused here: only PNG files are handled
        for i, file_path in enumerate(input_file_paths):
            # read and decode with TensorFlow ops; the image ends up float32 in [0, 1]
            image = tf.io.read_file(file_path)
            image = tf.image.decode_png(image, channels = 3)
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.resize(image, [dims[0],dims[1]])
            yield image, image
    
    def augment(image,label) :
        # rotate by a fixed pi/8 radians (22.5 degrees)
        img = tfa.image.rotate(image, tf.constant(np.pi/8))   
        # the mapped function must return the new (image, label) element
        return (img, label)
    
    X_train_paths = [os.path.join('data','img',name) for name in os.listdir('data/img')]
    dims = (256,256,3)
    
    train_dataset = tf.data.Dataset.from_generator(generator=customGenerator, 
                                                     output_types=(tf.float32, tf.float32), 
                                                     output_shapes=(dims, dims), 
                                                     args=[X_train_paths, dims, "png"])
    
    train_dataset = train_dataset.map(augment)
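If a random angle is the goal (the question's last point), the angle has to come from a TensorFlow op inside the mapped function. The sketch below assumes 90-degree steps are acceptable, so it can use core TensorFlow's tf.image.rot90 instead of tensorflow_addons; for arbitrary angles the same pattern would pass a tf.random.uniform angle to tfa.image.rotate. It also swaps the PNG files for a hypothetical NumPy generator (numpy_generator and DIMS are made up), which works because the generator body runs as ordinary Python, and uses output_signature, the TF 2.4+ replacement for output_types/output_shapes.

```python
import numpy as np
import tensorflow as tf

DIMS = (64, 64, 3)  # hypothetical size; the answer uses (256, 256, 3)

def numpy_generator(n_images):
    # from_generator runs this body as ordinary Python, so NumPy is fine here.
    rng = np.random.default_rng(0)
    for _ in range(n_images):
        img = rng.random(DIMS, dtype=np.float32)
        yield img, img

def random_rot90(image, label):
    # Draw k in {0, 1, 2, 3} with a TF op so each element gets its own value.
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    # rot90 with a tensor k loses the static height/width; since H == W,
    # ensure_shape restores it.
    rotated = tf.ensure_shape(tf.image.rot90(image, k), DIMS)
    return rotated, label

dataset = tf.data.Dataset.from_generator(
    numpy_generator,
    output_signature=(
        tf.TensorSpec(shape=DIMS, dtype=tf.float32),
        tf.TensorSpec(shape=DIMS, dtype=tf.float32),
    ),
    args=(4,),
)
dataset = dataset.map(random_rot90, num_parallel_calls=tf.data.AUTOTUNE)

for rotated, original in dataset.take(1):
    print(rotated.shape, original.shape)
```

The rotated image stays paired with the unrotated copy, matching the answer's (rotated, normal) layout.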
    

    Iterating over the dataset:

    for images in train_dataset:
        rotatedimg, normalimg = images[0], images[1]
        break
    

    Output:

    plt.imshow(rotatedimg)
    

    [image: the rotated sample]

    plt.imshow(normalimg)
    

    [image: the original sample]

    Things to remember:

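    1. The function passed to map is not executed eagerly: tf.data traces it into a graph, so its arguments are symbolic tensors such as args_0:0, even though tf.executing_eagerly() returns True in the surrounding code.
    2. Use TensorFlow ops inside the map function, because TensorFlow executes it as part of a graph in order to speed up the input pipeline. The generator given to from_generator is different: its body runs as ordinary Python, so NumPy or Matplotlib calls also work there.
    3. If the map function calls other (non-TensorFlow) code that needs concrete values, TensorFlow cannot convert those operations into the graph, which results in an error. Such code can be wrapped in tf.py_function.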
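Point 1 also explains why the commented-out random.uniform(...) line in the question would not give each image its own angle: Python code in the mapped function runs once, at trace time, so its result is frozen into the graph. A small sketch comparing it with tf.random.uniform (the two angle functions are hypothetical):

```python
import math
import random

import tensorflow as tf

def python_random_angle(_):
    # random.uniform runs once while tf.data traces this function, and the
    # value becomes a graph constant shared by every element.
    return tf.constant(random.uniform(0.0, math.pi), dtype=tf.float32)

def tf_random_angle(_):
    # tf.random.uniform is a graph op, so it draws a new value per element.
    return tf.random.uniform([], minval=0.0, maxval=math.pi)

dataset = tf.data.Dataset.range(5)
python_angles = [float(a) for a in dataset.map(python_random_angle)]
tf_angles = [float(a) for a in dataset.map(tf_random_angle)]
print("Python random:", python_angles)  # the same angle five times
print("tf.random:    ", tf_angles)      # five different angles
```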