I am trying to implement an augmentation function for my images and masks. I have defined the augmentations as follows:
# When the configured augmentation is "PRIMEAugmentation", build the pool of
# candidate operations. The listed names (autocontrast, equalize, ...) are
# defined elsewhere and are not shown in this snippet.
# NOTE(review): this binds a plain variable named `augmentations`; the class
# below reads `self.augmentations` — confirm how this list reaches the instance.
if config.AUG == "PRIMEAugmentation":
augmentations = [autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, translate_x, translate_y]
The augmentation class is defined as follows:
import torch
from torch.distributions import Dirichlet, Beta
# Callable that builds `mixture_width` augmented copies of an image/mask pair,
# accumulates them with Dirichlet-sampled weights, and blends the result with
# the original inputs using a Beta-sampled coefficient.
# NOTE(review): indentation was lost in this paste, so the nesting of the two
# loops and of the accumulation lines cannot be recovered from here — confirm
# against the real file before relying on the comments below.
class PRIMEAugmentation:
# mixture_width: number of augmented branches that are mixed together.
# mixture_depth: number of operations chained per branch. With the default of
# -1, range(-1) is empty, so no operation is applied to any branch.
def __init__(self, mixture_width=3, mixture_depth=-1):
self.mixture_width = mixture_width
self.mixture_depth = mixture_depth
# x, mask: NumPy arrays (both are passed to torch.from_numpy).
# Returns a tuple (mixed image, mixed mask) converted back to NumPy arrays.
def __call__(self, x, mask):
x = torch.from_numpy(x)
mask = torch.from_numpy(mask)
# One weight vector of length mixture_width per index along x's first axis:
# ws has shape (x.shape[0], mixture_width).
# NOTE(review): this treats x.shape[0] as a batch size — TODO confirm the
# caller passes a batched array rather than a single image.
ws = Dirichlet(torch.ones(self.mixture_width)).sample((x.shape[0],))
# A single Beta(1, 1) draw, expanded to the 4-D shape (x.shape[0], 1, 1, 1).
m = Beta(torch.ones(1), torch.ones(1)).sample().expand(x.shape[0], 1, 1, 1)
# Accumulators have exactly the same shape as the inputs.
x_aug = torch.zeros_like(x)
mask_aug = torch.zeros_like(mask)
for i in range(self.mixture_width):
# Each branch starts from an untouched copy of the inputs.
x_i = x.clone()
mask_i = mask.clone()
for d in range(self.mixture_depth):
# NOTE(review): `op` is a Python list of x.shape[0] indices, and it is used
# directly as an index on the next line; `self.augmentations` is also never
# assigned in the __init__ shown here. Neither line runs with the default
# mixture_depth=-1 — verify both before using a positive depth.
op = torch.randint(len(self.augmentations), size=(x.shape[0],)).tolist()
x_i, mask_i = self.augmentations[op](x_i, mask_i)
# NOTE(review): ws[:, i] is 1-D of length x.shape[0]; the four unsqueeze calls
# turn it into a 5-D tensor of shape (n, 1, 1, 1, 1), and x_i gains an extra
# axis at position 1. For a 3-D x the product broadcasts to more dimensions
# than x_aug has, so the in-place += raises a RuntimeError.
x_aug += ws[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(1) * x_i.unsqueeze(1)
mask_aug += ws[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(1) * mask_i.unsqueeze(1)
# Blend original and mixture: (1 - m) * original + m * mixture.
# NOTE(review): sum(dim=1) presumes a separate mixture axis at dim 1, but the
# accumulators were created with zeros_like and so have the input's own
# shape — confirm which axis is meant to be reduced, if any.
mixed = (1 - m) * x + m * x_aug.sum(dim=1)
# NOTE(review): the mask is blended arithmetically just like the image —
# confirm downstream code accepts non-discrete mask values.
mixed_mask = (1 - m) * mask + m * mask_aug.sum(dim=1)
return mixed.numpy(), mixed_mask.numpy()
I call it in the following way:
# Instantiate the augmenter with its default arguments (mixture_width=3,
# mixture_depth=-1). The variable name follows the 'augmenter_' + <name>
# pattern that image_mask_transformation resolves at call time.
augmenter_PRIMEAugmentation = aug_lib_new.PRIMEAugmentation()
import os
# Apply the base transform `img_trans` to an image/mask pair and, when
# `aug_trans` names a known augmenter, run that augmenter on the result.
# image, mask: inputs forwarded to img_trans as keyword arguments.
# img_trans: callable returning a mapping with "image" and "mask" entries.
# aug_trans: name of the augmenter to apply; the default False applies none
# unless False itself is listed in augmenter_list.
# NOTE(review): the snippet ends without a return statement — presumably it is
# truncated; confirm against the full function.
def image_mask_transformation(image,mask,img_trans,aug_trans=False):
transformed = img_trans(image=image, mask=mask)
image = transformed["image"]
mask = transformed["mask"]
# augmenter_list is defined elsewhere (not shown in this snippet).
if aug_trans in augmenter_list:
# Looks up the object named 'augmenter_<aug_trans>' by evaluating that string,
# then calls it on the transformed pair.
# NOTE(review): eval executes arbitrary text; a dict mapping names to augmenter
# objects would give the same lookup without evaluating a string.
image,mask = eval('augmenter_'+aug_trans)(image, mask)
However, I am getting the following error:
x_aug += ws[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(1) *
x_i.unsqueeze(1) RuntimeError: output with shape [320, 320, 3] doesn't
match the broadcast shape [320, 320, 320, 320, 3]
Assuming that `x_i` has the same shape as `x_aug`, and `mask_i` has the same shape as `mask_aug`, no broadcasting is required between those. However, you said `ws[:, i]` was 1D of shape `(320,)`. That means you need to unsqueeze exactly two dimensions, no more, so that it becomes `(320, 1, 1)` and broadcasts against the `(320, 320, 3)` tensors.
Indexing twice with `None` should work:
# ws[:, i] has shape (n,); indexing with [:, None, None] makes it (n, 1, 1),
# which broadcasts against a 3-D x_i of shape (n, ..., ...) without adding
# dimensions, so the in-place add keeps the accumulator's shape.
# NOTE(review): this applies one weight per index along the first axis —
# confirm that axis is the intended one for per-sample weighting.
x_aug += ws[:, i][:,None,None] * x_i
mask_aug += ws[:, i][:,None,None] * mask_i