Search code examples
Tags: python, tensorflow, tensorflow-datasets

Using an if/else statement inside a TensorFlow `Dataset.map` function


I would like to crop the images stored in a TensorFlow dataset to square images. However, TensorFlow does not seem to allow an if/else statement inside the function passed to `map`. How can I solve this problem? Many thanks in advance.

 def tf_resize(img, new_size=256):
    """Crop `img` to a square and resize it to `new_size` x `new_size`.

    NOTE(review): this version fails when used with `Dataset.map` -- it raises
    the TypeError quoted in this question, because `img.shape` is the static
    shape and, for the decoded JPEGs here, its height and width are None.
    """
 
    # Static shape: inside Dataset.map, h and w come back as None.
    h, w,__ = img.shape
    # This is the line that fails (None - None). Even with known sizes,
    # math.ceil of an integer difference would be a no-op, and an offset of
    # |w - h| keeps the right/bottom part of the image; if a centred crop was
    # intended, the offset would be abs(w - h) // 2.
    start = math.ceil(abs(w-h))  
    # tf.constant(w>h) wraps a Python bool computed once while tracing, so the
    # predicate is a trace-time constant, not derived from each image's
    # run-time shape.
    img_corp = tf.cond(tf.constant(w>h, dtype=tf.bool), 
                       lambda: img[:, start:(start+h), :], 
                       lambda: img[start:(start+w), :, :])
    # Resize, then divide by 255.0 to rescale; presumably assumes 0-255 pixel
    # values (uint8 from decode_jpeg) -- confirm.
    new_img = tf.image.resize(img_corp, [new_size, new_size])/255.0
    return new_img

def load_and_preprocess_image(path, new_size=256):
    """Read the file at `path` and decode it as a 3-channel JPEG tensor.

    NOTE(review): `new_size` is not used; no resizing happens here.
    """
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    return image

# Build a dataset of JPEG paths, decode each file, then crop/resize it.
# NOTE(review): `path` (the root directory) is defined outside this snippet;
# the comprehension below reuses the name `path` for its own loop variable.
data_dir = pathlib.Path(path) 
# Matches files exactly one sub-directory deep: <root>/<subdir>/<name>.jpg.
all_image_paths = [str(path) for path in list(data_dir.glob('*/*.jpg'))] 
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

image_ds = path_ds.map(load_and_preprocess_image)
# Raises the TypeError below: tf_resize reads the static img.shape.
image_ds = image_ds.map(tf_resize)  # this causes the error

The error message is shown below:

File "/tmp/ipykernel_14519/4089718206.py", line 26, in tf_resize  *
        start = math.ceil(abs(w-h))

    TypeError: unsupported operand type(s) for -: 'NoneType' and 'NoneType'

Solution

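  • You have a few errors in your code. Inside `Dataset.map`, `img.shape` is the static shape, and for decoded JPEGs its height and width are `None`, which is why `w-h` raises the `TypeError`. Use `tf.shape` to get the dynamic (run-time) shape of `img` instead, and compute the `tf.cond` condition from tensors rather than from Python values: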

    import tensorflow as tf
    
    def tf_resize(img, new_size=256):
        """Crop ``img`` to a square, resize it, and divide it by 255.0.

        Works inside ``tf.data.Dataset.map``: height and width are read with
        ``tf.shape`` (the run-time shape) rather than ``img.shape``, whose
        height/width are ``None`` for decoded JPEGs while the map function is
        being traced.

        Args:
            img: rank-3 image tensor indexed as (height, width, channels).
            new_size: side length of the square output.

        Returns:
            The square crop resized to ``[new_size, new_size]`` and divided by
            255.0.
        """
        shape = tf.shape(img)  # int32 tensor, evaluated per image at run time
        h, w = shape[0], shape[1]
        # Number of excess columns (wide image) or rows (tall image). Skipping
        # them keeps the right/bottom square; use tf.abs(w - h) // 2 as the
        # offset for a centred crop instead.
        start = tf.abs(w - h)
        # `w > h` is a boolean tensor, so tf.cond chooses the branch per image.
        square = tf.cond(
            w > h,
            lambda: img[:, start:start + h, :],
            lambda: img[start:start + w, :, :],
        )
        return tf.image.resize(square, [new_size, new_size]) / 255.0
    
    def load_and_preprocess_image(path, new_size=256):
        """Return the JPEG at ``path`` decoded into a 3-channel image tensor.

        ``new_size`` is accepted but not used; no resizing happens here.
        """
        file_bytes = tf.io.read_file(path)
        return tf.image.decode_jpeg(file_bytes, channels=3)
    
    # Single example path; replace with the real list of image files.
    all_image_paths = ['image.jpg']
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

    image_ds = path_ds.map(load_and_preprocess_image)
    # Works now: tf_resize reads the run-time shape via tf.shape.
    image_ds = image_ds.map(tf_resize)

    for image in image_ds.take(1):
        print(image.shape)
    
    (256, 256, 3)