
Pad torch tensors of different sizes to be equal


I am looking for a way to take an image/target batch for segmentation and return the batch where the image dimensions have been changed to be equal for the whole batch. I have tried this using the code below:

def collate_fn_padd(batch):
    '''
    Pads a batch of variable-size images and masks.

    note: conversion to tensors is done manually here, since the ToTensor
    transform assumes it receives images rather than arbitrary tensors.
    '''
    # separate the image and masks
    image_batch,mask_batch = zip(*batch)

    # pad the images and masks
    image_batch = torch.nn.utils.rnn.pad_sequence(image_batch, batch_first=True)
    mask_batch = torch.nn.utils.rnn.pad_sequence(mask_batch, batch_first=True)

    # rezip the batch
    batch = list(zip(image_batch, mask_batch))

    return batch

However, I get this error:

RuntimeError: The expanded size of the tensor (650) must match the existing size (439) at non-singleton dimension 2.  Target sizes: [3, 650, 650].  Tensor sizes: [3, 406, 439]

How do I efficiently pad the tensors to be of equal dimensions and avoid this issue?


Solution

  • rnn.pad_sequence only pads the sequence dimension and requires all other dimensions to be equal, so it cannot be used to pad images across two dimensions (height and width).

    Instead, torch.nn.functional.pad can be used, but the target height and width it should pad to have to be determined manually.

    import torch.nn.functional as F
    
    # Determine the maximum height and width.
    # The masks have the same height and width
    # as the images, since they mask the image.
    max_height = max([img.size(1) for img in image_batch])
    max_width = max([img.size(2) for img in image_batch])
    
    image_batch = [
        # The needed padding is the difference between the
        # max width/height and the image's actual width/height.
        F.pad(img, [0, max_width - img.size(2), 0, max_height - img.size(1)])
        for img in image_batch
    ]
    mask_batch = [
        # Same as for the images, but there is no channel dimension,
        # therefore the mask's width is dimension 1 instead of 2.
        F.pad(mask, [0, max_width - mask.size(1), 0, max_height - mask.size(0)])
        for mask in mask_batch
    ]
    

    The padding lengths are specified in reverse order of the dimensions, with two values per dimension: one for the padding at the beginning and one for the padding at the end. For an image with the dimensions [channels, height, width] the padding is given as [width_beginning, width_end, height_beginning, height_end], which can be reworded as [left, right, top, bottom]. The code above therefore pads the images on the right and bottom. The channels are left out because they are not padded, which also means the same padding list can be applied directly to the masks.
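    Putting the pieces together, a complete collate function might look like the sketch below. It is one possible variant, not the only correct one: here the padded images and masks are stacked into two batch tensors with torch.stack (which is what most models expect), rather than re-zipped into a list of pairs as in the question.

    ```python
    import torch
    import torch.nn.functional as F

    def collate_fn_padd(batch):
        """Pad images and masks to a common height/width, then stack them.

        Each batch element is an (image, mask) pair, where the image has
        shape [channels, height, width] and the mask [height, width].
        """
        image_batch, mask_batch = zip(*batch)

        # The target size is the largest height and width in the batch.
        max_height = max(img.size(1) for img in image_batch)
        max_width = max(img.size(2) for img in image_batch)

        # Pad on the right and bottom: [left, right, top, bottom].
        images = torch.stack([
            F.pad(img, [0, max_width - img.size(2), 0, max_height - img.size(1)])
            for img in image_batch
        ])
        masks = torch.stack([
            F.pad(mask, [0, max_width - mask.size(1), 0, max_height - mask.size(0)])
            for mask in mask_batch
        ])
        return images, masks
    ```

    This can be passed directly to a DataLoader via its collate_fn argument, e.g. DataLoader(dataset, batch_size=4, collate_fn=collate_fn_padd).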