Search code examples
Tags: deep-learning, pytorch, generator

Using batches results in errors ("3D or 4D tensor expected for input", "expected input to have 9 channels, but got 4 channels")


I have this code that works fine.

import torch
import torch.nn as nn
import numpy as np

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# attention module for FRU
class attention_FRU(nn.Module):
    """Gated 1x1 feature refinement unit (FRU).

    ``guide`` is turned into element-wise gates by a 1x1 conv, LeakyReLU(0.2)
    and Sigmoid.  ``x`` goes through a 1x1 conv and a BatchNorm without
    learnable scale/shift; the result is multiplied by the gates and passed
    through LeakyReLU(0.2).

    Args:
        num_channels_down: channel count of ``guide``, ``x`` and the output.
        pad: ``padding_mode`` given to both 1x1 convolutions.
    """

    def __init__(self, num_channels_down, pad='reflect'):
        super().__init__()
        ch = num_channels_down
        # Gate generator; the author explicitly wants LeakyReLU (not Softplus)
        # in front of the Sigmoid.
        self.gen_se_weights1 = nn.Sequential(
            nn.Conv2d(ch, ch, 1, padding_mode=pad),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Sigmoid(),
        )
        # Feature path: 1x1 conv followed by non-affine batch normalisation.
        self.conv_1 = nn.Conv2d(ch, ch, 1, padding_mode=pad)
        self.norm_1 = nn.BatchNorm2d(ch, affine=False)
        self.actvn = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, guide, x):
        gates = self.gen_se_weights1(guide)
        features = self.norm_1(self.conv_1(x))
        return self.actvn(features * gates)

class hs_net(nn.Module):
    """Guided network over a channel-stacked input.

    ``forward`` splits its 4-D input along dim 1.  The first ``ym_channel``
    channels (``ym``) drive a convolutional encoder/decoder with skip
    connections.  The remaining channels (``yh``, expected to number
    ``yh_channel``) feed a chain of seven ``attention_FRU`` steps.  Each step
    gates one decoder feature map with weights computed from the ``yh``
    branch.  The output has ``yh_channel`` sigmoid-activated channels and the
    input's spatial size, because every convolution used in forward() has
    'same' padding or a 1x1 kernel.

    ``enc``, ``skip``, ``dc``, ``conv_bn`` and ``FRU`` are single modules
    applied repeatedly, so their weights are shared across depths.

    Args:
        ym_channel: number of leading (guide) channels of the input.
        yh_channel: number of trailing input channels and of output channels.
        num_channels_down: encoder width; must equal ``num_channels_up``.
        num_channels_up: decoder width.
        num_channels_skip: width of the skip-connection projections.
        filter_size_down: kernel size of the encoder and yh-branch convolutions.
        filter_size_up: kernel size of the decoder convolution.
        filter_skip_size: kernel size of the skip convolution.

    Raises:
        ValueError: if ``num_channels_down != num_channels_up``.  In that case
            forward() could never run, because encoder outputs reach ``dc``
            (sized for ``num_channels_up``) and decoder outputs reach ``FRU``
            (sized for ``num_channels_down``).
    """

    def __init__(self, ym_channel, yh_channel, num_channels_down, num_channels_up, num_channels_skip,
                 filter_size_down, filter_size_up, filter_skip_size):
        super().__init__()
        if num_channels_down != num_channels_up:
            raise ValueError(
                f"num_channels_down ({num_channels_down}) must equal "
                f"num_channels_up ({num_channels_up})")

        self.FRU = attention_FRU(num_channels_down)
        # up_bic / up_trans are not used by forward(); they are kept so that
        # parameters() and state_dict() keys stay unchanged.
        self.up_bic = nn.Upsample(scale_factor=4, mode='bicubic')
        self.up_trans = nn.ConvTranspose2d(yh_channel, yh_channel, filter_size_down, stride=4, padding=1)

        self.guide_ms = nn.Sequential(
            nn.Conv2d(ym_channel, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_down),
            nn.Softplus())

        self.enc = nn.Sequential(
            nn.Conv2d(num_channels_down, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_down),
            nn.Softplus())

        self.skip = nn.Sequential(
            nn.Conv2d(num_channels_down, num_channels_skip, filter_skip_size, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_skip),
            nn.Softplus())

        # Input = skip projection concatenated with the previous decoder output.
        self.dc = nn.Sequential(
            nn.Conv2d(num_channels_skip + num_channels_up, num_channels_up, filter_size_up, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_up),
            nn.Softplus())
        self.out_layer = nn.Sequential(
            nn.Conv2d(num_channels_up, yh_channel, 1, padding_mode='reflect'),
            nn.Sigmoid())
        self.conv_hs = nn.Sequential(
            nn.Conv2d(yh_channel, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'))
        self.conv_bn = nn.Sequential(
            nn.Conv2d(num_channels_down, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'),
            nn.Softplus())
        self.ym_channels = ym_channel

    def forward(self, inputs):
        ym = inputs[:, :self.ym_channels, :, :]
        yh = inputs[:, self.ym_channels:, :, :]

        # Encoder: guide projection, then four passes of the shared enc block
        # -> [en0, en1, en2, en3, en4].
        encoded = [self.guide_ms(ym)]
        for _ in range(4):
            encoded.append(self.enc(encoded[-1]))

        # Decoder: dc0 and dc1 come from two more enc passes.  dc2..dc6 each
        # fuse a skip projection of en4, en3, ..., en0 with the previous output.
        dc0 = self.enc(encoded[-1])
        decoded = [dc0, self.enc(dc0)]
        for skip_src in reversed(encoded):
            fused = torch.cat((self.skip(skip_src), decoded[-1]), dim=1)
            decoded.append(self.dc(fused))

        # yh branch: each FRU computes gates from the yh features and applies
        # them to the next decoder map (dc0 first, dc6 last).
        hs = self.FRU(self.conv_hs(yh), decoded[0])
        for dec_feat in decoded[1:]:
            hs = self.FRU(self.conv_bn(hs), dec_feat)

        return self.out_layer(hs)

class MyDataGenerator(torch.utils.data.Dataset):
    """Dataset whose items are pre-assembled mini-batches of ``data``.

    Item ``i`` is ``data[i * batch_size:(i + 1) * batch_size]``, a slice along
    the first (sample) axis; the last item may be shorter.  Because every item
    is already a batch, index the dataset directly or wrap it in
    ``torch.utils.data.DataLoader(dataset, batch_size=None)``, which disables
    automatic batching.  DataLoader's default ``batch_size=1`` would add an
    extra leading dimension to every item.

    Args:
        data: tensor with samples along dim 0.
        batch_size: maximum number of samples per item (>= 1).

    Raises:
        ValueError: if ``batch_size < 1``.
    """

    def __init__(self, data, batch_size):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Remove singleton dimensions as torch.squeeze did, except the leading
        # sample axis.  Squeezing a (1, C, H, W) tensor to (C, H, W) made every
        # item a slice of channels instead of samples.
        kept = [size for size in data.shape[1:] if size != 1]
        self.data = data.reshape(data.shape[0], *kept)
        self.batch_size = batch_size

    def __len__(self):
        # Integer ceiling division: a partial final batch still counts.
        return -(-self.data.shape[0] // self.batch_size)

    def __getitem__(self, index):
        n_batches = len(self)
        if index < 0:
            index += n_batches
        # Raising IndexError, instead of returning an empty slice, is also what
        # ends a plain `for batch in dataset` loop.
        if not 0 <= index < n_batches:
            raise IndexError(f"batch index out of range (dataset has {n_batches} batches)")
        start = index * self.batch_size
        return self.data[start:start + self.batch_size]
    
    
num_iter = 1000
LR = 0.001
ms_channels = 9     # leading (guide) channels of every input sample
n_channels = 172    # trailing channels, also the network's output channels

# Dummy input: one sample whose channel layout [ms_channels | n_channels]
# matches the split done in hs_net.forward.  torch.rand creates float32 data
# directly on `device`, avoiding a float64 NumPy temporary (~380 MB at this
# size) plus a separate conversion copy.
inputs = torch.rand(1, ms_channels + n_channels, 512, 512, device=device)

net = hs_net(ym_channel=ms_channels,
             yh_channel=n_channels,
             num_channels_down=16,
             num_channels_up=16,
             num_channels_skip=16,
             filter_size_down=1,
             filter_size_up=1,
             filter_skip_size=1).to(device)

# Dummy targets: msi at 512 x 512, hsi at 128 x 128.
msi = torch.rand(ms_channels, 512, 512, device=device)
hsi = torch.rand(n_channels, 128, 128, device=device)

targets = [msi, hsi]
gband = msi.shape[0]  # not used elsewhere in this snippet
optimizer = torch.optim.Adam(net.parameters(), lr=LR, eps=1e-3, amsgrad=True)

def hs_loss(model, inputs, targets, loss_fn=None):
    """Run ``model`` on ``inputs`` and optionally score the prediction.

    Args:
        model: callable (e.g. an ``hs_net``) mapping ``inputs`` to a prediction.
        inputs: batch passed unchanged to ``model``.
        targets: reference data, forwarded only to ``loss_fn``.
        loss_fn: optional ``loss_fn(prediction, targets) -> loss``.

    Returns:
        The prediction alone when ``loss_fn`` is None (the original behaviour);
        otherwise the pair ``(loss, prediction)``.

    NOTE(review): callers in this file unpack ``loss, out_HR = hs_loss(...)``.
    That only works when a ``loss_fn`` is supplied.  Without one, the
    prediction tensor itself is unpacked along dim 0, which raises ValueError
    for a one-sample batch.  The intended loss is not defined in this file --
    TODO supply it.
    """
    prediction = model(inputs)
    if loss_fn is None:
        return prediction
    return loss_fn(prediction, targets), prediction
   
        
# NOTE(review): called as below, hs_loss(net, inputs, targets) returns only the
# network output, a (1, 172, 512, 512) tensor here.  Unpacking it into two
# names iterates over dim 0 (length 1) and raises ValueError.  The loop also
# never calls loss.backward() or optimizer.step(), so no training happens as
# shown.  TODO: define the loss and complete the update step.
for it in range(num_iter):
    optimizer.zero_grad()
    loss, out_HR = hs_loss(net, inputs, targets)
        

But when I try to use batches:

  # NOTE(review): MyDataGenerator squeezes `inputs` from (1, 181, 512, 512) to
  # (181, 512, 512), so each item it returns is a slice of 4 *channels*,
  # shape (4, 512, 512).  DataLoader's default batch_size=1 then stacks one
  # item per batch, giving (1, 4, 512, 512): hence "expected ... 9 channels,
  # but got 4 channels".  The fix is a dataset whose __getitem__ returns one
  # sample, with DataLoader(..., batch_size=4) building the batches.
  train_set = MyDataGenerator(inputs, batch_size=4)
  data_generator = torch.utils.data.DataLoader(train_set)
    
    for it in range(num_iter):
        optimizer.zero_grad()
        for batch in data_generator:
            loss, out_HR = hs_loss(net, batch, targets)

It gives me:

Given groups=1, weight of size [16, 9, 1, 1], expected input[1, 4, 512, 512] to have 9 channels, but got 4 channels instead

If I try:

# NOTE(review): here each item is the first 9 *channels*, shape (9, 512, 512),
# and the DataLoader turns it into (1, 9, 512, 512).  hs_net.forward gives all
# 9 channels to ym and leaves yh = inputs[:, 9:] with 0 channels, which the
# first layer applied to yh (conv_hs) rejects: the {1,0,512,512} error below.
train_set = MyDataGenerator(inputs, batch_size=9)

I receive:

3D or 4D (batch mode) tensor expected for input, but got: [ torch.cuda.FloatTensor{1,0,512,512} ]

Solution

  • The issue is that you are not using torch.utils.data.Dataset the way it is intended; please read the documentation page for more information. You don't have to assemble the batches yourself: your dataset's __getitem__ should return a single element at a time, and it is the job of torch.utils.data.DataLoader to collate those elements into batches of the requested batch size. Here is a demonstration following your example:

    class MyDataset(Dataset):
        """Map-style dataset returning one sample per index; DataLoader batches them."""

        def __init__(self, data):
            # Drop singleton dims as torch.squeeze would, but keep the leading
            # sample axis: otherwise a one-sample tensor (1, C, H, W) becomes
            # (C, H, W) and the dataset yields C channel slices instead.
            kept = [size for size in data.shape[1:] if size != 1]
            self.data = data.reshape(data.shape[0], *kept)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]
    

    First, define dummy data (make sure it contains more than one sample along the first dimension). Then initialize the dataset and wrap it with a data loader:

    >>> inputs = torch.rand(10, 181, 512, 512)
    >>> dataloader = DataLoader(MyDataset(inputs), batch_size=4)
    

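    Now you can iterate over dataloader, which uses a sampler to walk through the dataset. With 10 samples and batch_size=4, the last batch holds only the remaining 2 samples (pass drop_last=True to DataLoader to skip it):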

    >>> for batch in dataloader:
    ...     # batch has a shape of (4, 181, 512, 512)