Tags: python, matrix, pytorch, parameters, neural-network

PyTorch custom network doesn't update the weight matrix


I have written this simple custom network for a linear classifier on MNIST.

The catch is that the model operates through a global adjacency matrix of the entire network to perform its calculations. The matrix is almost all zeros, with only the bottom-left block being non-zero.
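
Concretely, in block form (this is exactly the layout the code below builds, with W the trainable block):

    A = [ 0 (input x input)    0 (input x output)  ]
        [ W (output x input)   0 (output x output) ]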

The model itself is very basic: only two layers, without any non-linearity. The problem is that during training the adjacency matrix does not get updated, so the model does not learn, and I don't know why. I have tested my training loop on more standard architectures and everything works fine (I am using SGD with cross-entropy loss), so the problem must be in how I have defined the network class. It is crucial for me to operate through this global adjacency matrix, and I would like to understand where the problem is and how to make it work.

import torch
import torch.nn as nn

class Simple_Direct_Network_Adjacency_Matrix_Implementation_Dim2(nn.Module):
    def __init__(self, input_dim, middle_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        _ = middle_dim  # This is a hack: we want dim 2 now, so this input to the class gets ignored
        self.output_dim = output_dim
        self.total_dim = self.input_dim + self.output_dim

        self.subdiagonal_block = nn.Parameter(torch.empty(self.output_dim, self.input_dim))
        nn.init.normal_(self.subdiagonal_block, mean=0, std=0.1)

        self.adjacency_matrix = self.make_subdiagonal_matrix().requires_grad_(requires_grad=True)

    def make_subdiagonal_matrix(self):
        over_block = torch.zeros(self.input_dim, self.input_dim)
        side_block = torch.zeros(self.total_dim, self.output_dim)

        matrix = torch.cat((over_block, self.subdiagonal_block), 0)
        matrix = torch.cat((matrix, side_block), 1)

        return matrix

    def forward(self, batch_of_inputs):
        # Flatten the batch of input images
        flat_inputs = batch_of_inputs.view(-1, batch_of_inputs.size(0))

        # Append zeros to match the total dimension
        flat_inputs_total = torch.cat((flat_inputs, torch.zeros(self.output_dim, flat_inputs.size(1))), dim=0)

        # Perform the matrix multiplication
        y_total_final = torch.mm(self.adjacency_matrix, flat_inputs_total)

        # Extract logits
        logits = y_total_final[-self.output_dim:, :].t()

        return logits

Note that I have also tried omitting the requires_grad_ call, and nothing changes; I don't know whether it is necessary. Also note that the matrix of parameters, the one specified with nn.Parameter(), doesn't change either. Moving the construction of the adjacency matrix inside the forward function also does not seem to solve the problem.


Solution

  • Your adjacency_matrix isn't updated because it's not an nn.Parameter.

  • Your subdiagonal_block isn't updated because it's not used in your forward pass.
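
Here is a minimal sketch of a fix that keeps the global-adjacency-matrix design: register only subdiagonal_block as the parameter and rebuild the full matrix inside forward, so the torch.cat operations are recorded in the autograd graph on every call and gradients reach the block. The class name and the transposed flattening (which keeps each image in its own column, unlike the original view(-1, batch_size) reshape) are my choices, not from the question:

    import torch
    import torch.nn as nn

    class SubdiagonalNet(nn.Module):  # illustrative name
        def __init__(self, input_dim, output_dim):
            super().__init__()
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.total_dim = input_dim + output_dim
            # The only trainable tensor: the bottom-left block of the adjacency matrix.
            self.subdiagonal_block = nn.Parameter(torch.empty(output_dim, input_dim))
            nn.init.normal_(self.subdiagonal_block, mean=0, std=0.1)

        def make_adjacency_matrix(self):
            # Built from the Parameter at call time, so the concatenations
            # stay in the autograd graph of the current forward pass.
            over_block = torch.zeros(self.input_dim, self.input_dim, device=self.subdiagonal_block.device)
            side_block = torch.zeros(self.total_dim, self.output_dim, device=self.subdiagonal_block.device)
            matrix = torch.cat((over_block, self.subdiagonal_block), 0)
            return torch.cat((matrix, side_block), 1)

        def forward(self, batch_of_inputs):
            # (batch, features) -> (features, batch), one image per column
            flat_inputs = batch_of_inputs.view(batch_of_inputs.size(0), -1).t()
            zeros = torch.zeros(self.output_dim, flat_inputs.size(1), device=flat_inputs.device)
            flat_inputs_total = torch.cat((flat_inputs, zeros), dim=0)
            # The adjacency matrix is rebuilt here on every forward pass.
            y_total_final = torch.mm(self.make_adjacency_matrix(), flat_inputs_total)
            return y_total_final[-self.output_dim:, :].t()

    # Quick check that gradients now reach the trainable block:
    model = SubdiagonalNet(784, 10)
    x, y = torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))
    loss = nn.CrossEntropyLoss()(model(x), y)
    loss.backward()
    print(model.subdiagonal_block.grad.abs().sum() > 0)  # tensor(True)

The key point is that the Parameter must participate in the graph recorded during each forward pass; a matrix assembled once in __init__ is a stale, non-leaf tensor that the optimizer never sees.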