python, pytorch

In mini-batch learning, is the input to a fully connected layer 1D or 2D?


In mini-batch learning, is the input a 1D array that makes no distinction between the samples in a mini-batch? Or is it a 2D array, with one dimension for the mini-batch and one for the data?

Specifically, I would like to know how to feed a test mini-batch of shape torch.Size([batch_size=20, height=256, width=256]) into the fully connected layer.

My current code:

class Net(nn.Module):
    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(in_features, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, out_features)

    def forward(self, x):
        x = self.fc1(x)       
        x = self.relu(x)
        out = self.fc2(x)



# input:torch.Size([batch_size=20, height=256, width=256]) 
model = Net(20*256*256, 20*256*256)

input = torch.flatten(input)
output = model(input)
output = torch.reshape(output, (20, 256, 256))

(Never mind that there is no conv layer.)

I am not sure whether this approach is appropriate. Can anyone explain?


Solution

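  • To answer your question: the input to a fully connected layer (nn.Linear) can be a tensor with any number of dimensions, provided the size of the last dimension matches the layer's in_features. All leading dimensions are treated as batch dimensions, so for a mini-batch the usual layout is 2D: (batch_size, features).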
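    A quick shape check shows this behaviour. The layer sizes 256 and 8 below are arbitrary, chosen only for illustration: nn.Linear transforms the last dimension and leaves every leading dimension untouched.

```python
import torch
from torch import nn

fc = nn.Linear(in_features=256, out_features=8)

# 2D input: (batch, features), the usual layout for a mini-batch
x2 = torch.randn(20, 256)
print(fc(x2).shape)  # torch.Size([20, 8])

# 3D input: every leading dimension is treated as a batch dimension,
# so each 256-long row is transformed independently
x3 = torch.randn(20, 256, 256)
print(fc(x3).shape)  # torch.Size([20, 256, 8])

# 1D input: a single unbatched sample also works
x1 = torch.randn(256)
print(fc(x1).shape)  # torch.Size([8])
```

    Note that passing the (20, 256, 256) batch straight into nn.Linear(256, ...) would transform each 256-pixel row separately rather than each whole image, which is usually not what you want here; each image has to be flattened to 256*256 features first.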

    There are some issues with your code.

    First, don't use input as a variable name: input() is a built-in Python function, and reassigning that name shadows it. Instead, give the tensor its own name and pass it to your model:

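    BATCH_SIZE, IMG_VER, IMG_HOR = 20, 256, 256
    # flatten each image but keep the batch dimension: (20, 256, 256) -> (20, 65536)
    input_tensor = torch.randn((BATCH_SIZE, IMG_VER, IMG_HOR)).flatten(start_dim=1)
    model = Net(IMG_VER * IMG_HOR, IMG_VER * IMG_HOR)
    output_tensor = model(input_tensor)
    output_tensor = torch.reshape(output_tensor, (BATCH_SIZE, IMG_VER, IMG_HOR))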
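    The key detail is which dimensions get flattened. With the question's sizes, flattening the whole tensor merges the batch dimension into the features, whereas flatten(start_dim=1) keeps one row per sample:

```python
import torch

batch = torch.randn(20, 256, 256)

# flatten() with no arguments merges the batch dimension too:
# one vector of 20*256*256 elements, so the samples are no longer separable
print(batch.flatten().shape)                   # torch.Size([1310720])

# flatten(start_dim=1) keeps dim 0 (the batch) and merges the rest:
# 20 vectors of 256*256 = 65536 features, the layout nn.Linear expects
print(batch.flatten(start_dim=1).shape)        # torch.Size([20, 65536])

# an equivalent reshape that infers the feature count from the data
print(batch.reshape(batch.size(0), -1).shape)  # torch.Size([20, 65536])
```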
    

    Second, the number of input and output features of a fully connected layer should not include the batch size; they describe a single sample. That is,

    model = Net(256*256, 256*256)
    

    instead of

    model = Net(20*256*256, 20*256*256)
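    To see the cost concretely, the sketch below compares the two choices for the first layer, using the question's hidden width of 128 but a scaled-down 16×16 "image" so it runs instantly (these sizes are assumptions for illustration). Baking the batch size into in_features multiplies the weight count by the batch size and ties the model to that exact batch size.

```python
import torch
from torch import nn

BATCH, H, W, HIDDEN = 20, 16, 16, 128


def n_params(module: nn.Module) -> int:
    # total number of trainable values (weights + biases)
    return sum(p.numel() for p in module.parameters())


per_sample = nn.Linear(H * W, HIDDEN)           # features of ONE sample
whole_batch = nn.Linear(BATCH * H * W, HIDDEN)  # batch size baked in

print(n_params(per_sample))   # 32896  (256*128 + 128)
print(n_params(whole_batch))  # 655488 (5120*128 + 128)

# The per-sample layer accepts any batch size...
print(per_sample(torch.randn(7, H * W)).shape)  # torch.Size([7, 128])

# ...while the whole-batch layer only fits exactly 20 flattened samples,
# so a batch of 7 is rejected with a shape-mismatch error.
try:
    whole_batch(torch.randn(7, H, W).flatten())
    mismatch = False
except RuntimeError:
    mismatch = True
print("batch of 7 rejected:", mismatch)  # True
```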
    

    Third, it's best not to assume that whoever uses your model knows they need to reshape the input. Instead, do the reshaping inside the model. Fourth, forward() must return its output; without a return statement, calling the model gives None. See this:

    import torch
    from torch import nn
    
    
    class Net(nn.Module):
        def __init__(self, in_features, out_features):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(in_features, 128)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(128, out_features)
    
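        def forward(self, x):
            # flatten each sample but keep the batch dimension:
            # (N, 256, 256) -> (N, 65536), for any batch size N
            x = self.fc1(x.flatten(start_dim=1))
            x = self.relu(x)
            out = self.fc2(x)
            return out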
    
    
    model = Net(256*256, 256*256)
    input_tensor = torch.randn((20, 256, 256))
    output_tensor = model(input_tensor)
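    Going one step further on the third point, the model can also restore the image shape on the way out, so callers pass and receive (batch, height, width) tensors and never reshape by hand. This is a sketch under an assumption: the class name ImageMLP and its image_shape argument are hypothetical (not part of the original Net), and the design presumes the output has the same height and width as the input. A small 32×32 size is used so the example runs quickly.

```python
import torch
from torch import nn


class ImageMLP(nn.Module):
    """Fully connected model mapping (N, H, W) inputs to (N, H, W) outputs.

    `image_shape` is a hypothetical constructor argument for this sketch.
    """

    def __init__(self, image_shape, hidden=128):
        super().__init__()
        self.image_shape = tuple(image_shape)
        n_pixels = self.image_shape[0] * self.image_shape[1]
        self.fc1 = nn.Linear(n_pixels, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, n_pixels)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.flatten(start_dim=1)   # (N, H, W) -> (N, H*W)
        x = self.relu(self.fc1(x))   # (N, hidden)
        x = self.fc2(x)              # (N, H*W)
        return x.reshape(batch_size, *self.image_shape)  # back to (N, H, W)


model = ImageMLP((32, 32))
for n in (1, 7, 20):  # the batch size is free to vary
    out = model(torch.randn(n, 32, 32))
    print(out.shape)  # [1, 32, 32], then [7, 32, 32], then [20, 32, 32]
```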