For a nice output in TensorBoard, I want to show a batch of input images, the corresponding target masks and the output masks in a grid. The input images have a different size than the masks. Furthermore, the images are obviously RGB, while the masks are single-channel. From a batch of, e.g., 32 or 64, I only want to show the first 4 images.
After some fiddling around, I came up with the following example code. The good thing: it works. But I am really not sure whether I missed something in PyTorch; it just looks much longer than I expected. The upsampling and the conversion to RGB in particular seem wild, but the other transformations I found would not work on a whole batch.
```python
import time

import torch
import torch.nn.functional as FN
import torchvision.utils as vutils
from tensorboardX import SummaryWriter

batch = 32
i_size = 192
o_size = 112
nr_imgs = 4

# Tensorboard init
writer = SummaryWriter('runs/' + time.strftime('%Y%m%d_%H%M%S'))

# plain tensors; Variable wrappers are no longer needed since PyTorch 0.4
input_image = torch.rand(batch, 3, i_size, i_size)
target_mask = torch.rand(batch, o_size, o_size)
output_mask = torch.rand(batch, o_size, o_size)

# upsample target_mask, add dim to have gray2rgb
tm = FN.interpolate(target_mask[:nr_imgs, None], size=[i_size, i_size],
                    mode='bilinear', align_corners=False)
tm = torch.cat((tm, tm, tm), dim=1)  # grayscale plane to rgb
# upsample output_mask, add dim to have gray2rgb
om = FN.interpolate(output_mask[:nr_imgs, None], size=[i_size, i_size],
                    mode='bilinear', align_corners=False)
om = torch.cat((om, om, om), dim=1)  # grayscale plane to rgb

# add up all images and make grid (row 1: inputs, row 2: targets, row 3: outputs)
imgs = torch.cat((input_image[:nr_imgs], tm, om)).detach()
x = vutils.make_grid(imgs, nrow=nr_imgs, normalize=True, scale_each=True)

# Tensorboard img output
writer.add_image('Image', x, 0)
writer.close()
```
EDIT: I found this on PyTorch's issue list: Batch support for Transform. It seems there are no plans to add batch transforms in the future. So my current code might be the best solution for the time being anyway?
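Even without batch support in the transforms themselves, many per-image transforms are plain tensor arithmetic that broadcasts over the batch dimension. A minimal sketch of that idea, using per-image min-max scaling to [0, 1] (roughly what normalize=True, scale_each=True asks make_grid to do); the helper name minmax_per_image and the eps guard are assumptions for illustration, not part of torchvision.

```python
import torch


def minmax_per_image(batch: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Rescale each image of an (N, C, H, W) batch independently to [0, 1]. Hypothetical helper."""
    n = batch.shape[0]
    flat = batch.reshape(n, -1)
    # per-image min/max, shaped (N, 1, 1, 1) so they broadcast over C, H and W
    lo = flat.min(dim=1).values.view(n, 1, 1, 1)
    hi = flat.max(dim=1).values.view(n, 1, 1, 1)
    # clamp(min=eps) guards against division by zero for constant images (assumed eps)
    return (batch - lo) / (hi - lo).clamp(min=eps)


if __name__ == "__main__":
    x = torch.randn(4, 3, 8, 8) * 5 + 2
    y = minmax_per_image(x)
    print(y.reshape(4, -1).min(dim=1).values, y.reshape(4, -1).max(dim=1).values)
```

The same pattern (compute statistics per sample, reshape them to broadcast, apply one vectorized expression) replaces a Python loop over the batch for most pointwise transforms.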
Maybe you could just convert your tensors to NumPy arrays (.data.cpu().numpy()) and use OpenCV for the upsampling? The OpenCV implementation should be quite fast.