python, pytorch, torch

How to multiply a 4-dimensional PyTorch tensor by a 1-dimensional tensor?


I'm trying to write a function for mixup training. I found some code on this site and adapted it to my existing code. The original code generates only one random value for the whole batch (64 images), but I want a separate random value for every image in the batch. Here is the code with one value per batch:

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index,:]
    mixed_y = lam * y + (1 - lam) * y[index,:]

    return mixed_x, mixed_y

x and y come from a PyTorch DataLoader. x has size torch.Size([64, 3, 256, 256]) and y has size torch.Size([64, 3474]).

This code works fine. Then I changed it to this:

def mixup_data(x, y):
    batch_size = x.size()[0]
    lam = torch.rand(batch_size)
    index = torch.randperm(batch_size)

    mixed_x = lam[index] * x + (1 - lam[index]) * x[index,:]
    mixed_y = lam[index] * y + (1 - lam[index]) * y[index,:]

    return mixed_x, mixed_y

But it gives an error: RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

My understanding is that the code should take the first image in the batch and multiply it by the first value in the lam tensor (which has 64 values), and so on for each image. How can I do that?


Solution

  • You need to replace the following line:

    lam = torch.rand(batch_size)
    

    with

    lam = torch.rand(batch_size, 1, 1, 1)
    

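    With your current code, the multiplication lam[index] * x fails because lam[index] has size torch.Size([64]) whereas x has size torch.Size([64, 3, 256, 256]). Broadcasting aligns shapes starting from the last dimension, so the 64 values of lam[index] are matched against the last dimension of x (256), which is the mismatch the error reports. Giving lam[index] the size torch.Size([64, 1, 1, 1]) makes it broadcastable against x, with one value per image.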

    The following statement needs the same care, because y has size torch.Size([64, 3474]):

    mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
    

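    Here a lam of size torch.Size([64, 1, 1, 1]) would not raise an error but would broadcast to the wrong output shape, so reshape lam to torch.Size([64, 1]) before the statement: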

    lam = lam.reshape(batch_size, 1)
    mixed_y = lam[index] * y + (1 - lam[index]) * y[index, :]
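
Putting both reshapes together, a minimal sketch of a per-sample version might look like the following. The function name mixup_data_per_sample and the small stand-in shapes are assumptions for illustration, not part of the original code. lam is drawn with torch.rand as in the question, and view is used to shape it once for x and once for y.

```python
import torch


def mixup_data_per_sample(x, y):
    """Mixup with one mixing coefficient per sample (hypothetical helper name)."""
    batch_size = x.size(0)
    # One random value in [0, 1) per sample, as in the question's second version.
    lam = torch.rand(batch_size, device=x.device)
    index = torch.randperm(batch_size, device=x.device)

    # Reshape lam to (batch_size, 1, ..., 1) so it broadcasts over the
    # trailing dimensions of each tensor: (64, 1, 1, 1) for x, (64, 1) for y.
    lam_x = lam.view(batch_size, *([1] * (x.dim() - 1)))
    lam_y = lam.view(batch_size, *([1] * (y.dim() - 1)))

    # Sample i is mixed with sample index[i], using the same lam[i] for x and y.
    mixed_x = lam_x * x + (1 - lam_x) * x[index]
    mixed_y = lam_y * y + (1 - lam_y) * y[index]
    return mixed_x, mixed_y


# Small stand-in shapes instead of (64, 3, 256, 256) and (64, 3474).
x = torch.randn(8, 3, 16, 16)
y = torch.rand(8, 10)
mixed_x, mixed_y = mixup_data_per_sample(x, y)
print(mixed_x.shape, mixed_y.shape)
```

Building the view shape from x.dim() and y.dim() avoids hard-coding the number of trailing dimensions, so the same function should work if the inputs are not images. Indexing lam with index, as in the code above the sketch, is not needed here: lam is already random, and the same lam[i] is applied to both x[i] and y[i].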