
How to normalize a batch of tensors in pytorch in a speed optimized manner


I have an input batch, a list of 8 images, each of shape (480, 640, 3), which I would like to convert to PyTorch tensors, normalize with mean and std, and pass to a model as a single tensor of shape (8, 3, 480, 640). Presently I'm doing the following, which works.

import torch as T
from torchvision import transforms

batch_size=8
height = 480
width = 640
input_shape = (batch_size, 3, height, width)

transform = transforms.Compose([              
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
])


input_batch = [...]  # list of 8 arrays of shape (480, 640, 3), e.g. [np.ones((480, 640, 3))] * 8

pre_items = [transform(item) for item in input_batch]
pre_items = T.stack(pre_items).to("cuda")

This is obviously not optimal, because each image is preprocessed individually on the CPU and only the stacked result is moved to CUDA.
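For reference, the per-image pipeline above reduces to one elementwise formula. Per the torchvision documentation, ToTensor() turns an H×W×C array into a C×H×W float tensor and divides by 255 only when the input is uint8; Normalize() then subtracts the per-channel mean and divides by the per-channel std. For a uint8 batch x indexed (n, h, w, c), the target output indexed (n, c, h, w) is:

$$
y_{n,c,h,w} = \frac{x_{n,h,w,c} / 255 - \mu_c}{\sigma_c},\qquad
\mu = (0.485,\ 0.456,\ 0.406),\quad \sigma = (0.229,\ 0.224,\ 0.225)
$$

For float input, ToTensor() does no rescaling, so the division by 255 drops out.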

What's the correct way to perform this on GPU on the batch as a whole?

My attempt at a solution was:

import torch as T

batch_size = 8
height = 480
width = 640

mean = T.ones((batch_size, height, width, 3)).to("cuda") * T.tensor([0.485, 0.456, 0.406]).to("cuda")
std = T.ones((batch_size, height, width, 3)).to("cuda") * T.tensor([0.229, 0.224, 0.225]).to("cuda")

input_batch = T.stack([T.tensor(item).to("cuda").float() for item in input_batch])
pre_items = (input_batch - mean)/std
pre_items = T.permute(pre_items, (0,3,1,2))

The output of this script does not match the expected tensor from the bottlenecked solution.
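A likely cause of the mismatch, assuming the real images are uint8: transforms.ToTensor() divides uint8 input by 255, while the attempt normalizes the raw 0 to 255 values. The sketch below emulates the per-image transform by hand (no torchvision needed, runs on CPU) and compares it with the attempt, with and without that scaling. The helper names reference_transform and attempt, and the small random test images, are hypothetical stand-ins.

```python
import numpy as np
import torch

MEAN = torch.tensor([0.485, 0.456, 0.406])
STD = torch.tensor([0.229, 0.224, 0.225])


def reference_transform(item: np.ndarray) -> torch.Tensor:
    """Hand-written equivalent of ToTensor() + Normalize() for one uint8 HWC image."""
    x = torch.from_numpy(item).permute(2, 0, 1).float() / 255.0  # HWC uint8 -> CHW float in [0, 1]
    return (x - MEAN[:, None, None]) / STD[:, None, None]        # per-channel normalize


def attempt(batch: np.ndarray, scale: bool) -> torch.Tensor:
    """The question's attempt, optionally with the missing division by 255."""
    x = torch.from_numpy(batch).float()  # (N, H, W, C)
    if scale:
        x = x / 255.0
    out = (x - MEAN) / STD               # (3,) broadcasts over the last (channel) dim
    return out.permute(0, 3, 1, 2)       # -> (N, C, H, W)


rng = np.random.default_rng(0)
# Small stand-in for eight 480x640 RGB images.
batch = rng.integers(0, 256, size=(8, 48, 64, 3), dtype=np.uint8)

expected = torch.stack([reference_transform(item) for item in batch])
print(torch.allclose(attempt(batch, scale=False), expected, atol=1e-5))  # False
print(torch.allclose(attempt(batch, scale=True), expected, atol=1e-5))   # True
```

If the inputs are float arrays instead, ToTensor() keeps them unscaled, so the remaining difference would be dtype (the attempt casts to float32) rather than the factor of 255.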


Solution

  • According to the OP's clarification, this is a fast way to perform the normalization on the GPU.

    • Note: I reshape the mean and std arrays so that they combine elementwise with input_batch without materializing copies of the same values for every pixel, as the ones-based mean and std in the question do (this is called broadcasting).
    import torch as T
    import numpy as np
    
    batch_size = 8
    height = 480
    width = 640
    channels = 3
    
    # CPU
    input_batch = np.ones((batch_size, height, width, channels))
    mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 1, -1))  # match input_batch dimension
    std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 1, -1))  # match input_batch dimension
    
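    # GPU: cast to float32, PyTorch's default floating-point dtype
    input_batch = T.from_numpy(input_batch).to("cuda").float()
    mean = T.from_numpy(mean).to("cuda").float()
    std = T.from_numpy(std).to("cuda").float()
    
    # If the real images are uint8 in [0, 255], divide input_batch by 255 first
    # to match what transforms.ToTensor() does.
    pre_items = (input_batch - mean)/std
    pre_items = pre_items.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W) for the model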
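The same broadcasting idea can also be applied after moving to the (N, C, H, W) layout the model expects. A minimal sketch, assuming the images arrive as a list of uint8 HWC numpy arrays: stack them once on the CPU, copy the uint8 batch to the device in a single transfer (a quarter of the bytes of float32), then convert, scale, permute and normalize there. Here mean and std are shaped (1, 3, 1, 1) instead of the answer's (1, 1, 1, -1), because the permute happens first. The function name preprocess_batch is hypothetical, and the code falls back to CPU when CUDA is unavailable.

```python
import numpy as np
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shaped (1, C, 1, 1) so they broadcast over an (N, C, H, W) batch.
MEAN = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)


def preprocess_batch(images: list) -> torch.Tensor:
    """Stack uint8 HWC images, move them once, then normalize on the device."""
    batch = torch.from_numpy(np.stack(images))          # (N, H, W, C) uint8, still on CPU
    batch = batch.to(DEVICE)                            # single host-to-device copy of uint8 data
    batch = batch.permute(0, 3, 1, 2).float() / 255.0   # (N, C, H, W) float32 in [0, 1], like ToTensor
    return (batch - MEAN) / STD                         # per-channel Normalize via broadcasting


images = [np.full((480, 640, 3), 255, dtype=np.uint8) for _ in range(8)]
out = preprocess_batch(images)
print(out.shape)  # torch.Size([8, 3, 480, 640])
```

For these all-white test images every value in channel c equals (1 - mean[c]) / std[c], which is what the per-image ToTensor() + Normalize() pipeline would produce for the same input.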