
RuntimeError: output with shape [320, 320, 3] doesn't match the broadcast shape [320, 320, 320, 320, 3]


I am trying to apply an augmentation function to my images and masks. I have defined the augmentations as follows:

if config.AUG == "PRIMEAugmentation":
    augmentations = [autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, translate_x, translate_y]

The augmentation class is defined as follows:

import torch
from torch.distributions import Dirichlet, Beta

class PRIMEAugmentation:
    def __init__(self, mixture_width=3, mixture_depth=-1):
        self.mixture_width = mixture_width
        self.mixture_depth = mixture_depth

    def __call__(self, x, mask):
        x = torch.from_numpy(x)
        mask = torch.from_numpy(mask)
        ws = Dirichlet(torch.ones(self.mixture_width)).sample((x.shape[0],))
        m = Beta(torch.ones(1), torch.ones(1)).sample().expand(x.shape[0], 1, 1, 1)

        x_aug = torch.zeros_like(x)
        mask_aug = torch.zeros_like(mask)
        for i in range(self.mixture_width):
            x_i = x.clone()
            mask_i = mask.clone()
            for d in range(self.mixture_depth):
                op = torch.randint(len(self.augmentations), size=(x.shape[0],)).tolist()
                x_i, mask_i = self.augmentations[op](x_i, mask_i)
            x_aug += ws[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(1) * x_i.unsqueeze(1)
            mask_aug += ws[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(1) * mask_i.unsqueeze(1)

        mixed = (1 - m) * x + m * x_aug.sum(dim=1)
        mixed_mask = (1 - m) * mask + m * mask_aug.sum(dim=1)
        return mixed.numpy(), mixed_mask.numpy()

I call it as follows:

augmenter_PRIMEAugmentation = aug_lib_new.PRIMEAugmentation()

import os

def image_mask_transformation(image,mask,img_trans,aug_trans=False):
    transformed = img_trans(image=image, mask=mask)
    image = transformed["image"]
    mask = transformed["mask"]

    if aug_trans in augmenter_list:
        image,mask = eval('augmenter_'+aug_trans)(image, mask)

But I am getting this error:

    x_aug += ws[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(1) * x_i.unsqueeze(1)
    RuntimeError: output with shape [320, 320, 3] doesn't match the broadcast shape [320, 320, 320, 320, 3]

Solution

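  • Assuming that x_i has the same shape as x_aug, and mask_i has the same shape as mask_aug, no broadcasting is required between them, so the unsqueeze(1) on x_i and mask_i is not needed. The error message shows that x_aug has shape [320, 320, 3], so x.shape[0] is 320 and ws[:, i] is 1D with shape (320,). To multiply it against a 3D tensor you need to add two trailing dimensions, no more. Your code adds four singleton dimensions to ws[:, i] and one to x_i, so the product has shape [320, 320, 1, 320, 3], which cannot be accumulated in place into the 3D x_aug.

    Indexing with None twice should work: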

    x_aug += ws[:, i][:,None,None] * x_i
    mask_aug += ws[:, i][:,None,None] * mask_i
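
To check the shapes yourself, here is a minimal sketch that reproduces the mismatch and then applies the fix. It assumes a single height x width x channel image as in the error message, uses small stand-in sizes (6 x 5 x 3 instead of 320 x 320 x 3) to keep memory low, and the variable names w, bad, and raised are only for illustration:

```python
import torch
from torch.distributions import Dirichlet

# Stand-in sizes (assumption): a 6 x 5 image with 3 channels instead of 320 x 320 x 3.
H, W, C = 6, 5, 3
x = torch.rand(H, W, C)
ws = Dirichlet(torch.ones(3)).sample((x.shape[0],))  # shape (H, 3)

x_i = x.clone()
x_aug = torch.zeros_like(x)
w = ws[:, 0]  # 1D, shape (H,)

# Original expression: the weight becomes 5D and the image 4D,
# so the product is 5D rather than (H, W, C).
bad = w.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(1) * x_i.unsqueeze(1)
print(bad.shape)

# An in-place add cannot grow x_aug to the broadcast shape, hence the RuntimeError.
raised = False
try:
    x_aug += bad
except RuntimeError as err:
    raised = True
    print(err)

# Fix: add exactly two trailing dimensions so (H, 1, 1) broadcasts against (H, W, C).
x_aug += w[:, None, None] * x_i
print(x_aug.shape)
```

w.view(-1, 1, 1) or w.unsqueeze(-1).unsqueeze(-1) give the same (H, 1, 1) shape if you prefer those over None indexing.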