I need to build a data loader to train a CNN for semantic segmentation using TensorFlow. The training images are 3-channel TIFF files and the masks are 1-channel (grayscale) TIFF files.
So far, I have followed this example, which defines the following function:
    def parse_image(img_path: str) -> dict:
        image = tf.io.read_file(img_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.convert_image_dtype(image, tf.uint8)
        mask_path = tf.strings.regex_replace(img_path, "images", "annotations")
        mask_path = tf.strings.regex_replace(mask_path, "jpg", "png")
        mask = tf.io.read_file(mask_path)
        mask = tf.image.decode_png(mask, channels=1)
        mask = tf.where(mask == 255, np.dtype('uint8').type(0), mask)
        return {'image': image, 'segmentation_mask': mask}
which works well for JPEG and PNG images. For TIFF, however, one has to use tfio.experimental.image.decode_tiff(image), which is very limited and doesn't work in my case. It emits a lot of errors such as
TIFFReadDirectory: Warning, Unknown field with tag 42112 (0xa480) encountered.
As noted in this answer, I could use a package such as cv2 or PIL.
I tried to implement it as follows:
    import cv2

    def parse_image(img_path: str) -> dict:
        # read image
        image = cv2.imread(img_path)
        image = tf.convert_to_tensor(image, tf.uint8)
        # read mask
        mask_path = tf.strings.regex_replace(img_path, "X", "y")
        mask_path = tf.strings.regex_replace(mask_path, "X.tif", "y.tif")
        mask = cv2.imread(mask_path)
        mask = tf.convert_to_tensor(mask, tf.uint8)
        return {"image": image, "segmentation_mask": mask}
However, this only results in

    TypeError: in user code:

        <ipython-input-46-41b06b3732aa>:6 parse_image  *
            image = cv2.imread(img_path)

        TypeError: Can't convert object of type 'Tensor' to 'str' for 'filename'

and I suppose there will be many more problems when using non-TensorFlow functions inside this function.
Since I've seen a few older posts about similar problems with TensorFlow and TIFF, I wondered whether there is a workaround by now, e.g. some custom function that is compatible with the rest of TensorFlow and can read TIFF data?
If you're still willing to use opencv instead, then you can wrap your reading function in a tf.numpy_function. Inside the scope of the function wrapped in the tf.numpy_function, you deal with numpy arrays, so converting the numpy bytestring representation into a regular Python string is needed before calling cv2.imread.
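    import tensorflow as tf
    import cv2

    def parse_with_opencv(image_path):
        # image_path arrives as a bytes object, so decode it before calling OpenCV
        return cv2.imread(image_path.decode('UTF-8'))

    img_path = ["/path/to/image.tif"]
    # tf.numpy_function runs the Python function eagerly on each element of the dataset
    ds = tf.data.Dataset.from_tensor_slices(img_path).map(
        lambda x: tf.numpy_function(parse_with_opencv, [x], Tout=tf.uint8)
    )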
Dealing with tf.numpy_function is sometimes a bit frustrating, as error messages can be somewhat cryptic.
One other option would be to use the decode_tiff function from the tfio module. It is compatible with TensorFlow operations. Of course it has limitations.
An example:
    import tensorflow as tf
    import tensorflow_io as tfio

    def parse_image(image_path):
        # Read the raw file bytes, then decode them with the TensorFlow I/O TIFF decoder
        filecontent = tf.io.read_file(image_path)
        img = tfio.experimental.image.decode_tiff(filecontent)
        return img

    img_path = ["/path/to/image.tif"]
    ds = tf.data.Dataset.from_tensor_slices(img_path).map(parse_image)
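One of those limitations, as far as I can tell from the tfio documentation, is that decode_tiff returns a uint8 tensor of shape [height, width, 4] (RGBA), even for a 3-channel image or a grayscale mask. A possible way to get back to the shapes from the question is to slice the channels afterwards. The helper names below are made up for illustration, and taking the first channel as the mask assumes the decoder replicates the gray value across R, G and B, which is worth checking on your own files.

```python
import tensorflow as tf


def drop_alpha(rgba):
    # [H, W, 4] RGBA -> [H, W, 3] RGB by discarding the alpha channel
    return rgba[..., :3]


def first_channel_as_mask(rgba):
    # [H, W, 4] -> [H, W, 1]; assumes the gray value is stored in the first channel
    return rgba[..., :1]
```

Inside parse_image this would look like drop_alpha(tfio.experimental.image.decode_tiff(filecontent)) for the image and first_channel_as_mask(...) for the mask.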