Let's say I have a segmentation model (model) and I want to batch-transform its predictions to Pillow images. For simplicity, let's say everything runs on the CPU (no GPU involved).
If I do:
```python
import torch
from torchvision.transforms import ToPILImage

transform = ToPILImage()
model.eval()
for i, (x, y) in enumerate(dataloader):
    y_hat = torch.sigmoid(model(x))  # returns a tensor (batch_size, 1, H, W)
    y_hat = (y_hat > 0.5).float()
    img = transform(y_hat)
```
I get:
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.
Fair enough. Let me try using vmap to transform it as a batch:
```python
import torch
from torchvision.transforms import ToPILImage

transform = ToPILImage()
batch_transform = torch.func.vmap(transform)
model.eval()
for i, (x, y) in enumerate(dataloader):
    y_hat = torch.sigmoid(model(x))  # returns a tensor (batch_size, 1, H, W)
    y_hat = (y_hat > 0.5).float()
    img = batch_transform(y_hat)
```
That produces the following error:
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
Why does this behave this way? Does it have anything to do with the function I've chosen to vmap? I've followed the pattern in the documentation, so I expected this to work. How can I apply this transform to a batch of images?
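As the error message suggests, the ToPILImage transform operates on tensors that are either 2D (H, W) or 3D (C, H, W). The vmap attempt fails for a different reason: torch.func.vmap can only transform functions built from PyTorch tensor operations that return tensors. ToPILImage converts its input to a NumPy array and then to a PIL.Image, but inside vmap the input is a batched wrapper tensor with no real storage, so the NumPy conversion raises "Cannot access data pointer of Tensor that doesn't have storage". Even past that point, vmap could not stack PIL images into a batched output. This means you have to iterate over the batch elements and apply the transform: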
```python
imgs = [transform(t) for t in y_hat]
```
Alternatively, you can use torchvision.utils.make_grid to tile the whole batch (a 4D tensor or a list of tensors) into a single image tensor:

```python
from torchvision.utils import make_grid

img = transform(make_grid(y_hat))
```
The convenient torchvision.utils.save_image utility function combines make_grid, the PIL.Image conversion, and saving to the file system in one call:

```python
from torchvision.utils import save_image

save_image(y_hat, 'pred.jpg')
```