I have a tensor with shape [None, 128, 128, n_classes]. This is a one-hot tensor, where the last axis holds the categorical values for multiple classes (there are n_classes in total).
In practice, the last axis holds binary values that indicate the class of each pixel: e.g. when a pixel has 1 in channel C, it belongs to class C, and it has 0 in every other channel.
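To make the encoding concrete, here is a small NumPy sketch (NumPy is an assumption here; the post itself uses TensorFlow) that builds a one-hot tensor from a map of class indices and then recovers those indices with argmax, which is the same first step the implementations in this post take before colouring. The helper names to_one_hot and to_class_map are hypothetical.

```python
import numpy as np

def to_one_hot(class_map, n_classes):
    # class_map: integer array of shape [batch, H, W] with values in [0, n_classes).
    # Indexing the identity matrix by class id gives shape [batch, H, W, n_classes].
    return np.eye(n_classes, dtype=np.float32)[class_map]

def to_class_map(one_hot):
    # Inverse step: the position of the 1 along the last axis is the pixel's class.
    return np.argmax(one_hot, axis=-1)

labels = np.array([[[0, 1], [2, 4]]])   # batch of 1, a 2x2 "image", n_classes = 5
one_hot = to_one_hot(labels, 5)
print(one_hot.shape)                    # (1, 2, 2, 5)
print(one_hot[0, 1, 1])                 # [0. 0. 0. 0. 1.]
print(np.array_equal(to_class_map(one_hot), labels))  # True
```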
Now, I wish to convert this one-hot tensor to an RGB image, which I want to plot on TensorBoard. Every class has to be associated with a different colour so that the image is easier to interpret.
Any idea on how to do that?
Thanks, G.
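The palettes in the implementations below are hard-coded for five classes. For an arbitrary n_classes, one option (an assumption, not part of the original post) is to space hues evenly around the HSV colour wheel with Python's standard colorsys module. make_palette is a hypothetical helper name, and keeping class 0 black simply mirrors the hard-coded palettes, where class 0 is presumably the background.

```python
import colorsys

def make_palette(n_classes, saturation=0.65, value=0.9):
    """Return n_classes distinct (R, G, B) tuples with integer components in 0..255.

    Class 0 stays black; the other classes get evenly spaced hues.
    """
    palette = [(0, 0, 0)]
    for i in range(1, n_classes):
        # hues in [0, 1): classes 1..n_classes-1 are spread around the wheel
        hue = (i - 1) / max(n_classes - 1, 1)
        r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
        palette.append((round(r * 255), round(g * 255), round(b * 255)))
    return palette

print(len(make_palette(5)))   # 5
```

The result is a plain list indexed by class id, which is the form tf.constant can turn into an [n_classes, 3] lookup table.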
Edit 2:
Solution added in the answers.
Edit 1:
My current implementation (not working):
import tensorflow as tf

def from_one_hot_to_rgb(incoming, palette=None):
    """ Assign a different color to each class in the input tensor """
    if palette is None:
        palette = {
            0: (0, 0, 0),
            1: (31, 12, 33),
            2: (13, 26, 33),
            3: (21, 76, 22),
            4: (22, 54, 66)
        }

    def _colorize(value):
        return palette[value]

    # from one-hot to grayscale:
    cmap = tf.expand_dims(tf.argmax(incoming, axis=-1), axis=-1)
    # flatten input tensor (pixels on the first axis):
    # get_shape() is a helper (not shown); after expand_dims the last axis has size 1
    B, W, H, C = get_shape(cmap)  # batch_size, 128, 128, 1
    cmap_flat = tf.reshape(cmap, shape=[B * W * H, C])
    # assign a different color to each class:
    cmap = tf.map_fn(lambda pixel:
                     tf.py_func(_colorize, inp=[pixel], Tout=tf.int64),
                     cmap_flat)
    # back to original shape, but RGB output:
    cmap = tf.reshape(cmap, shape=[B, W, H, 3])
    return tf.cast(cmap, dtype=tf.float32)
A possible solution, similar to the initial code, is the following. Note that it can be very slow: tf.map_fn processes the pixels one at a time, and each pixel goes through a tf.py_func call back into Python.
import numpy as np
import tensorflow as tf

def from_one_hot_to_rgb_bkup(incoming, palette=None):
    if palette is None:
        # plain Python tuples: tf.py_func must return NumPy data, not tf.Tensors
        palette = {i: color for i, color in enumerate(
            ((0, 0, 0),
             (31, 12, 33),
             (13, 26, 33),
             (21, 76, 22),
             (22, 54, 66))
        )}
    # from one-hot to grayscale (one class index per pixel, flattened):
    _, W, H, _ = incoming.get_shape().as_list()
    gray = tf.reshape(tf.argmax(incoming, axis=-1, output_type=tf.int32), [-1, 1], name='flatten')
    # assign colors to each class; the returned array's dtype must match Tout
    rgb = tf.map_fn(lambda pixel:
                    tf.py_func(lambda value: np.array(palette[int(value[0])], dtype=np.int32),
                               inp=[pixel], Tout=tf.int32),
                    gray, name='colorize')
    # back to original shape (batch size inferred with -1), but RGB output:
    rgb = tf.reshape(rgb, shape=[-1, W, H, 3], name='back_to_rgb')
    return tf.cast(rgb, dtype=tf.float32)
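tf.py_func hands NumPy arrays to the wrapped Python function and expects a NumPy array back whose dtype matches Tout. The TensorFlow-free sketch below shows a per-pixel colouring function that satisfies that contract for Tout=tf.int32 and simulates the loop tf.map_fn runs over the flattened class map; colorize_pixel is a hypothetical name. It also illustrates the cost: a single 128x128 image means 16,384 such Python calls.

```python
import numpy as np

PALETTE = [(0, 0, 0), (31, 12, 33), (13, 26, 33), (21, 76, 22), (22, 54, 66)]

def colorize_pixel(value, palette=PALETTE):
    # `value` mimics one row of the flattened class map:
    # an int32 array of shape (1,) holding the class index.
    # Return an ndarray (not a tf.Tensor) with the dtype declared in Tout.
    return np.asarray(palette[int(value[0])], dtype=np.int32)

# Simulate the per-pixel loop over a flattened [n_pixels, 1] class map
gray = np.array([[0], [3], [1]], dtype=np.int32)
rgb = np.stack([colorize_pixel(pixel) for pixel in gray])
print(rgb.dtype, rgb.shape)   # int32 (3, 3)
print(rgb[1].tolist())        # [21, 76, 22]
```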
Based on this answer, a much faster solution uses tf.gather, which looks up the colours of all pixels in a single vectorized operation:
import tensorflow as tf

def from_one_hot_to_rgb(incoming, palette=None):
    if palette is None:
        # one (R, G, B) triple per class, indexed by class id
        palette = [(0, 0, 0),
                   (31, 12, 33),
                   (13, 26, 33),
                   (21, 76, 22),
                   (22, 54, 66)]
    _, W, H, _ = incoming.get_shape().as_list()
    # lookup table of shape [n_classes, 3]
    palette = tf.constant(palette, dtype=tf.uint8)
    # class index of every pixel, flattened to a vector
    class_indexes = tf.argmax(incoming, axis=-1)
    class_indexes = tf.reshape(class_indexes, [-1])
    # pick each pixel's colour from the lookup table
    color_image = tf.gather(palette, class_indexes)
    color_image = tf.reshape(color_image, [-1, W, H, 3])
    color_image = tf.cast(color_image, dtype=tf.float32)
    return color_image
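For checking the palette mapping outside a TensorFlow graph, here is a NumPy sketch of the same gather-based lookup (one_hot_to_rgb is a hypothetical name; np.take along axis 0 plays the role of tf.gather). Both functions keep the shape of the index array and append the palette's last axis, so the [batch, H, W] class map can be passed directly; in the TensorFlow version the flatten and reshape steps should therefore be optional.

```python
import numpy as np

def one_hot_to_rgb(one_hot, palette):
    # one_hot: array of shape [batch, H, W, n_classes]
    # palette: n_classes (R, G, B) triples, indexed by class id
    palette = np.asarray(palette, dtype=np.uint8)      # [n_classes, 3]
    class_indexes = np.argmax(one_hot, axis=-1)        # [batch, H, W]
    return np.take(palette, class_indexes, axis=0)     # [batch, H, W, 3]

PALETTE = [(0, 0, 0), (31, 12, 33), (13, 26, 33), (21, 76, 22), (22, 54, 66)]
one_hot = np.eye(5)[np.array([[[0, 1], [2, 4]]])]      # shape (1, 2, 2, 5)
rgb = one_hot_to_rgb(one_hot, PALETTE)
print(rgb.shape)               # (1, 2, 2, 3)
print(rgb[0, 1, 1].tolist())   # [22, 54, 66]
```

Keeping the result as uint8 (skipping the final cast to float32) also lets it go straight into tf.summary.image, which accepts uint8 images.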