Tags: tensorflow, tensorboard

How to visualize images in TensorBoard using tf.summary.image() when they are in [batch, channels, height, width] (NCHW) order?


First, I wanted to reshape a 2-D tensor into a 4-D tensor using tf.reshape().
I expected tf.reshape() to transform
[batch, array] -> [batch, height, width, channels] (NHWC) order,
but in practice it transformed
[batch, array] -> [batch, channels, height, width] (NCHW) order.

Example:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API

sess = tf.Session()

a = np.array([[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16],[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]])
print(a.shape)

# [batch_size, channels, height, width]
b = sess.run(tf.reshape(a, shape=[2, 3, 4, 4]))

# [batch_size, height, width, channels]
c = sess.run(tf.reshape(a, shape=[2, 4, 4, 3]))

print(b)
print('*******')
print(c)

The result was:

(2, 48)
[[[[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]

  [[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]

  [[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]]


 [[[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]

  [[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]

  [[ 1  2  3  4]
   [ 5  6  7  8]
   [ 9 10 11 12]
   [13 14 15 16]]]]
*******
[[[[ 1  2  3]
   [ 4  5  6]
   [ 7  8  9]
   [10 11 12]]

  [[13 14 15]
   [16  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 9 10 11]
   [12 13 14]
   [15 16  1]
   [ 2  3  4]]

  [[ 5  6  7]
   [ 8  9 10]
   [11 12 13]
   [14 15 16]]]


 [[[ 1  2  3]
   [ 4  5  6]
   [ 7  8  9]
   [10 11 12]]

  [[13 14 15]
   [16  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 9 10 11]
   [12 13 14]
   [15 16  1]
   [ 2  3  4]]

  [[ 5  6  7]
   [ 8  9 10]
   [11 12 13]
   [14 15 16]]]]
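This output is consistent with how reshape works in both NumPy and TensorFlow: it never moves elements, it only re-reads the flat data in row-major (C) order, with the last axis varying fastest. Reshape therefore does not "choose" NCHW or NHWC; which interpretation is correct depends on how the flat data was laid out in the first place. A minimal NumPy sketch using the question's array (`flat_index` is a hypothetical helper added for illustration):

```python
import numpy as np

# The question's input: two rows, each the sequence 1..16 repeated three times.
a = np.array([list(range(1, 17)) * 3] * 2)
assert a.shape == (2, 48)

# reshape re-reads the flat buffer in C (row-major) order: last axis fastest.
nchw = a.reshape(2, 3, 4, 4)  # axis 1 = three chunks of 16 consecutive values
nhwc = a.reshape(2, 4, 4, 3)  # last axis = three consecutive values


def flat_index(i, j, k, d2, d3):
    """Offset within one row of element [n, i, j, k] for a (N, d1, d2, d3) reshape."""
    return (i * d2 + j) * d3 + k


# Every element comes from the row-major offset; nothing is reordered.
assert nchw[0, 2, 1, 3] == a[0, flat_index(2, 1, 3, 4, 4)]
assert nhwc[0, 1, 2, 0] == a[0, flat_index(1, 2, 0, 4, 3)]

print(nchw[0, :, 0, 0])  # first pixel's 3 "channels" read as NCHW: [1 1 1]
print(nhwc[0, 0, 0, :])  # first pixel's 3 "channels" read as NHWC: [1 2 3]
```

Because each 16-value chunk of `a` is one plane, only the NCHW reading keeps each plane intact.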

So I set data_format='channels_first' on the conv and pooling layers so that they consume the reshaped tensor in NCHW order. Training went well; in fact, it gave better results, as mentioned by @mrry here. I think this is understandable because NCHW is cuDNN's default layout.

However, I cannot add the images to a summary with tf.summary.image() (documented here), because it requires a tensor of shape [batch_size, height, width, channels].

Moreover, if I train on and visualize the input images reshaped to [batch, height, width, channels] order, the images look incorrect.
It is also worth mentioning that the training results were not as good as with [batch, channels, height, width] order.
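A likely explanation for the incorrect-looking images, assuming the flat rows follow the layout documented for the CIFAR-10 Python batches (each 3072-value row holds the 1024 red values, then the 1024 green, then the 1024 blue, each plane in row-major order): such a row is channel-first, so reshaping it straight to height-width-channel mixes values from different pixels and planes. The sketch uses a synthetic row (not real CIFAR-10 data) where every red value is 0, every green 1, and every blue 2; the names `row`, `wrong`, and `right` are illustrative.

```python
import numpy as np

H = W = 32
# Synthetic CIFAR-10-style row: full R plane, then G plane, then B plane.
row = np.concatenate([np.full(H * W, 0), np.full(H * W, 1), np.full(H * W, 2)])
assert row.shape == (3072,)

# Wrong: treat the flat row as if it were already height-width-channel.
wrong = row.reshape(H, W, 3)

# Right: reshape to the layout the data actually has (C, H, W),
# then move the channel axis last to get HWC for display.
right = row.reshape(3, H, W).transpose(1, 2, 0)

print(wrong[0, 0])                # [0 0 0]: three red values from three different pixels
print(right[0, 0])                # [0 1 2]: the R, G, B values of pixel (0, 0)
print(np.unique(wrong[:, :, 0]))  # [0 1 2]: the "red" channel mixes all three planes
```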

There are two questions:
1. Why does tf.reshape() transform [batch, array] into NCHW order instead of NHWC order? I tested with both the CPU and GPU builds of TensorFlow and got the same result. np.reshape() also gives the same result. (This is why I think I may be misunderstanding something here.)
2. How can I visualize images in TensorBoard using tf.summary.image() when they are in NCHW order? (Question #2 was solved using advice from @Maosi Chen. Thanks.)

I trained the model on a GPU (TensorFlow 1.4); the images are from the CIFAR-10 dataset.
Thanks


Solution

  • You can reorder the dimensions with tf.transpose (https://www.tensorflow.org/api_docs/python/tf/transpose).

    Note that the elements of perm are dimension indices of the source tensor (the a argument of tf.transpose): dimension i of the result is dimension perm[i] of the source.

    import tensorflow as tf
    import numpy as np
    
    sess = tf.InteractiveSession()
    
    a = np.array([[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16],[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]])
    print(a.shape)
    
    # [batch_size, channels, height, width]
    b = sess.run(tf.reshape(a, shape=[2, 3, 4, 4]))
    
    # [batch_size, height, width, channels]
    c = sess.run(tf.transpose(b, perm=[0, 2, 3, 1]))
    
    print(b)
    print('*******')
    print(c)
    

    Results:

    (2, 48)
    [[[[ 1  2  3  4]
       [ 5  6  7  8]
       [ 9 10 11 12]
       [13 14 15 16]]

      [[ 1  2  3  4]
       [ 5  6  7  8]
       [ 9 10 11 12]
       [13 14 15 16]]

      [[ 1  2  3  4]
       [ 5  6  7  8]
       [ 9 10 11 12]
       [13 14 15 16]]]


     [[[ 1  2  3  4]
       [ 5  6  7  8]
       [ 9 10 11 12]
       [13 14 15 16]]

      [[ 1  2  3  4]
       [ 5  6  7  8]
       [ 9 10 11 12]
       [13 14 15 16]]

      [[ 1  2  3  4]
       [ 5  6  7  8]
       [ 9 10 11 12]
       [13 14 15 16]]]]
    *******
    [[[[ 1  1  1]
       [ 2  2  2]
       [ 3  3  3]
       [ 4  4  4]]

      [[ 5  5  5]
       [ 6  6  6]
       [ 7  7  7]
       [ 8  8  8]]

      [[ 9  9  9]
       [10 10 10]
       [11 11 11]
       [12 12 12]]

      [[13 13 13]
       [14 14 14]
       [15 15 15]
       [16 16 16]]]


     [[[ 1  1  1]
       [ 2  2  2]
       [ 3  3  3]
       [ 4  4  4]]

      [[ 5  5  5]
       [ 6  6  6]
       [ 7  7  7]
       [ 8  8  8]]

      [[ 9  9  9]
       [10 10 10]
       [11 11 11]
       [12 12 12]]

      [[13 13 13]
       [14 14 14]
       [15 15 15]
       [16 16 16]]]]
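To make the perm rule concrete: output axis i takes input axis perm[i], so perm=[0, 2, 3, 1] turns NCHW into NHWC, and its inverse permutation [0, 3, 1, 2] converts back. np.transpose uses the same axis convention as tf.transpose, so this NumPy sketch (random data, illustrative shapes chosen to be all different) should carry over directly:

```python
import numpy as np

rng = np.random.default_rng(0)
x_nchw = rng.random((2, 3, 4, 5))  # N=2, C=3, H=4, W=5

to_nhwc = [0, 2, 3, 1]
x_nhwc = np.transpose(x_nchw, to_nhwc)

# Output axis i has the size of input axis perm[i]: (2, 4, 5, 3).
assert x_nhwc.shape == tuple(x_nchw.shape[p] for p in to_nhwc)

# Same values, indexed differently: x_nhwc[n, h, w, c] == x_nchw[n, c, h, w].
assert x_nhwc[1, 2, 4, 0] == x_nchw[1, 0, 2, 4]

# argsort of a permutation gives its inverse, which restores the original layout.
to_nchw = list(np.argsort(to_nhwc))  # [0, 3, 1, 2]
assert np.array_equal(np.transpose(x_nhwc, to_nchw), x_nchw)
```

In the TF 1.x graph from the question, the same [0, 2, 3, 1] permutation would be applied to the NCHW input tensor before passing it to tf.summary.image(), which expects [batch_size, height, width, channels]; the conv and pooling layers can keep consuming the NCHW tensor unchanged.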