I would like to cut the image which are stored as tensorflow dataset to squared image. But it seems that the tensorflow could not allow the map function with ifelse statement. I hope to know whether I could solve this problem. Many thanks in advance.
def tf_resize(img, new_size=256):
h, w,__ = img.shape
start = math.ceil(abs(w-h))
img_corp = tf.cond(tf.constant(w>h, dtype=tf.bool),
lambda: img[:, start:(start+h), :],
lambda: img[start:(start+w), :, :])
new_img = tf.image.resize(img_corp, [new_size, new_size])/255.0
return new_img
def load_and_preprocess_image(path, new_size=256):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
return image
data_dir = pathlib.Path(path)
all_image_paths = [str(path) for path in list(data_dir.glob('*/*.jpg'))]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(load_and_preprocess_image)
image_ds = image_ds.map(tf_resize) # this cause error!!!
The error message is shown below:
File "/tmp/ipykernel_14519/4089718206.py", line 26, in tf_resize *
start = math.ceil(abs(w-h))
TypeError: unsupported operand type(s) for -: 'NoneType' and 'NoneType'
You have a few errors in your code. Try using tf.shape
to get the dynamic shape of img
:
import tensorflow as tf
def tf_resize(img, new_size=256):
img_shape = tf.cast(tf.shape(img), tf.float32)
w = img_shape[1]
h = img_shape[0]
start = tf.cast(tf.math.ceil(tf.abs(w-h)), dtype=tf.int32)
w = tf.cast(w, dtype=tf.int32)
h = tf.cast(h, dtype=tf.int32)
img_corp = tf.cond(tf.greater(w, h),
lambda: img[:, start:(start+h), :],
lambda: img[start:(start+w), :, :])
new_img = tf.image.resize(img_corp, [new_size, new_size])/255.0
return new_img
def load_and_preprocess_image(path, new_size=256):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
return image
all_image_paths = ['image.jpg']
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(load_and_preprocess_image)
image_ds = image_ds.map(tf_resize) # this cause error!!!
for d in image_ds.take(1):
print(d.shape)
(256, 256, 3)