How to convert multi-class one-hot tensor to RGB in TensorFlow?


I have a tensor with shape [None, 128, 128, n_classes]. This is a one-hot tensor: the last axis holds the categorical values for multiple classes (there are n_classes in total). In practice, the last axis contains binary values that indicate the class of each pixel: when a pixel has 1 in channel C, it belongs to class C, and it has 0 in every other channel.

Now I wish to convert this one-hot tensor to an RGB image that I can plot on TensorBoard. Every class has to be associated with a different colour so that the image is easier to interpret.

Any idea on how to do that?

Thanks, G.


Edit 2:

Solution added in the answers.


Edit 1:

My current implementation (not working):

def from_one_hot_to_rgb(incoming, palette=None):
    """ Assign a different color to each class in the input tensor """
    if palette is None:
        palette = {
            0: (0, 0, 0),
            1: (31, 12, 33),
            2: (13, 26, 33),
            3: (21, 76, 22),
            4: (22, 54, 66)
        }

    def _colorize(value):
        return palette[value]

    # from one-hot to grayscale:
    cmap = tf.expand_dims(tf.argmax(incoming, axis=-1), axis=-1)

    # flatten input tensor (pixels on the first axis):
    B, W, H, C = get_shape(cmap)  # this returns batch_size, 128, 128, 1
    cmap_flat = tf.reshape(cmap, shape=[B * W * H, C])

    # assign a different color to each class:
    cmap = tf.map_fn(lambda pixel:
                     tf.py_func(_colorize, inp=[pixel], Tout=tf.int64),
                     cmap_flat)

    # back to original shape, but RGB output:
    cmap = tf.reshape(cmap, shape=[B, W, H, 3])

    return tf.cast(cmap, dtype=tf.float32)

Solution

  • Solution 1 (slow)

    A possible solution, similar to the initial code, is the following. Note that it can be very slow: tf.map_fn is known to perform poorly in TensorFlow, and here it calls a Python function once per pixel.

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.py_func)

    def from_one_hot_to_rgb_bkup(incoming, palette=None):
    
        if palette is None:
            # NumPy arrays, because tf.py_func must return NumPy values, not tensors
            palette = {i: np.array(color, dtype=np.int32) for i, color in enumerate(
                ((0, 0, 0),
                (31, 12, 33),
                (13, 26, 33),
                (21, 76, 22),
                (22, 54, 66))
            )}
    
        # from one-hot to grayscale (one class index per pixel):
        # get_shape is a helper not shown here, assumed to return the static shape as a list
        _, W, H, _ = get_shape(incoming)
        gray = tf.reshape(tf.argmax(incoming, axis=-1, output_type=tf.int32), [-1, 1], name='flatten')
    
        # assign colors to each class (Tout must match the palette dtype)
        rgb = tf.map_fn(lambda pixel:
                        tf.py_func(lambda value: palette[int(value[0])], inp=[pixel], Tout=tf.int32),
                        gray, name='colorize')
    
        # back to original shape, but RGB output (-1 keeps the batch size dynamic):
        rgb = tf.reshape(rgb, shape=[-1, W, H, 3], name='back_to_rgb')
    
        return tf.cast(rgb, dtype=tf.float32)
    

  • Solution 2 (fast)

    Based on this answer, a much faster solution uses tf.gather to look up the color of every pixel in a palette tensor:

    import tensorflow as tf

    def from_one_hot_to_rgb(incoming, palette=None):
    
        if palette is None:
            # one RGB triplet per class: row i is the color of class i
            palette = [(0, 0, 0),
                       (31, 12, 33),
                       (13, 26, 33),
                       (21, 76, 22),
                       (22, 54, 66)]
    
        # get_shape is a helper not shown here, assumed to return the static shape as a list
        _, W, H, _ = get_shape(incoming)
        palette = tf.constant(palette, dtype=tf.uint8)  # shape [n_classes, 3]
        class_indexes = tf.argmax(incoming, axis=-1)
    
        # look up the palette row of every pixel, then restore the image shape
        class_indexes = tf.reshape(class_indexes, [-1])
        color_image = tf.gather(palette, class_indexes)
        color_image = tf.reshape(color_image, [-1, W, H, 3])
    
        color_image = tf.cast(color_image, dtype=tf.float32)
        return color_image
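
    To check the lookup logic without building a graph, the same idea can be sketched in NumPy, where indexing a palette array with an array of class indices plays the role of tf.gather. This is only an illustration, not part of the original solution: the function name one_hot_to_rgb_np and the tiny 2x2 input are made up for the example, and the palette values are the ones used above.

```python
import numpy as np

# Palette as an array: row i holds the RGB color of class i
# (same values as in the solution above)
PALETTE = np.array([(0, 0, 0),
                    (31, 12, 33),
                    (13, 26, 33),
                    (21, 76, 22),
                    (22, 54, 66)], dtype=np.uint8)


def one_hot_to_rgb_np(one_hot, palette=PALETTE):
    """Map a [B, W, H, n_classes] one-hot array to a [B, W, H, 3] RGB array.

    Hypothetical helper, written only to illustrate the palette lookup.
    """
    class_indexes = np.argmax(one_hot, axis=-1)  # [B, W, H] integer labels
    return palette[class_indexes]                # integer indexing acts like a gather


# Tiny input: a batch with one 2x2 image and 5 classes
labels = np.array([[[0, 1], [4, 2]]])            # shape [1, 2, 2]
one_hot = np.eye(5, dtype=np.float32)[labels]    # shape [1, 2, 2, 5]
rgb = one_hot_to_rgb_np(one_hot)
print(rgb.shape)
```

    Because the palette is indexed directly by the class-index array, no flattening and reshaping is needed here; the output keeps the leading dimensions of the input and gains a trailing axis of size 3.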