Tiff workaround for Tensorflow


I need to build a data loader to train a CNN for semantic segmentation using TensorFlow. The training images are 3-channel TIFFs and the masks are 1-channel (grayscale) TIFFs.

So far, I followed this example. They write a function

import numpy as np
import tensorflow as tf

def parse_image(img_path: str) -> dict:
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.uint8)
    mask_path = tf.strings.regex_replace(img_path, "images", "annotations")
    mask_path = tf.strings.regex_replace(mask_path, "jpg", "png")
    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.where(mask == 255, np.dtype('uint8').type(0), mask)

    return {'image': image, 'segmentation_mask': mask}

which works well for JPEG and PNG images. For TIFF, however, one has to use tfio.experimental.image.decode_tiff(image), which is very limited and doesn't work in my case. It emits many messages such as

TIFFReadDirectory: Warning, Unknown field with tag 42112 (0xa480) encountered.

As noted in this answer, I could use a package such as cv2 or PIL.

I tried to implement it as follows:

import cv2
def parse_image(img_path: str) -> dict:
    # read image
    image = cv2.imread(img_path)
    image = tf.convert_to_tensor(image, tf.uint8)
    # read mask
    mask_path = tf.strings.regex_replace(img_path, "X", "y")
    mask_path = tf.strings.regex_replace(mask_path, "X.tif", "y.tif")
    mask = cv2.imread(mask_path)
    mask = tf.convert_to_tensor(mask, tf.uint8)
    return {"image": image, "segmentation_mask": mask}

However, this only results in

TypeError: in user code:

    <ipython-input-46-41b06b3732aa>:6 parse_image  *
        image = cv2.imread(img_path)

    TypeError: Can't convert object of type 'Tensor' to 'str' for 'filename'

and I suppose there will be many more problems when using non-TensorFlow functions inside this function.

Since I've seen a few older posts about similar problems with TensorFlow and TIFF, I wondered whether there is a workaround by now, e.g. a custom function that is compatible with the rest of TensorFlow and can read TIFF data.


Solution

  • If you're willing to keep using OpenCV, you can wrap your reading function in tf.numpy_function. The wrapped function runs as ordinary Python and receives NumPy values rather than tensors. A string tensor arrives as a bytes object, so it has to be decoded into a regular Python string before calling cv2.imread.

    import tensorflow as tf
    import cv2
    
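    def parse_with_opencv(image_path):
        # image_path arrives as bytes, so decode it to str for OpenCV.
        # Note: cv2.imread returns channels in BGR order, and None if the read fails.
        return cv2.imread(image_path.decode('UTF-8'))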
    
    img_path = ["/path/to/image.tif"]
    
    ds = tf.data.Dataset.from_tensor_slices(img_path).map(
        lambda x: tf.numpy_function(parse_with_opencv, [x], Tout=tf.uint8)
    ) 
    

    Dealing with tf.numpy_function is sometimes a bit frustrating, as error messages can be somewhat cryptic.


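    Another option is the decode_tiff function from the tensorflow-io (tfio) package. It is a regular TensorFlow op, so it works inside Dataset.map without a wrapper. It does have limitations: it always returns an 8-bit RGBA tensor of shape [height, width, 4], and, as the question shows, it does not handle every TIFF file.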

    An example:

    import tensorflow as tf
    import tensorflow_io as tfio
    
    def parse_image(image_path):
        filecontent = tf.io.read_file(image_path)
        img = tfio.experimental.image.decode_tiff(filecontent)
        return img
    
    img_path = ["/path/to/image.tif"]
    
    ds = tf.data.Dataset.from_tensor_slices(img_path).map(parse_image)