Tags: python, deep-learning, neural-network, pytorch

How to implement a diagonal weight matrix for a linear layer in PyTorch


I would like to have a network in PyTorch that only scales the data.

The mathematical notation for my request is y = diag(w) x, i.e. y_i = w_i * x_i,

which means that if my input is [1, 2] and my output is [2, 6], then the linear layer will look like this:

[ [ 2, 0],
  [ 0, 3] ].
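
For concreteness, this is just an elementwise scaling of the input; a minimal sketch checking the numbers above (the scale factors 2 and 3 are taken from the example, not from my network):

import torch

w = torch.tensor([2.0, 3.0])   # per-feature scale factors
x = torch.tensor([1.0, 2.0])   # example input
y = torch.diag(w) @ x          # same as w * x elementwise
print(y)                       # tensor([2., 6.])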

I have this network written in PyTorch:

import torch
import torch.nn as nn

class ScalingNetwork(nn.Module):
    def __init__(self, input_shape, output_shape):
        super().__init__()
        self.linear_layer = nn.Linear(in_features=input_shape, out_features=output_shape)
        self.mask = torch.diag(torch.ones(input_shape))
        self.linear_layer.weight.data = self.linear_layer.weight * self.mask
        self.linear_layer.weight.requires_grad = True

    def get_tranformation_matrix(self):
        return self.linear_layer.weight


    def forward(self, X):
        X = self.linear_layer(X)
        return X

But at the end of training, the weight of my self.linear_layer is not diagonal. What am I doing wrong?


Solution

  • An apparent constraint here is that self.linear_layer needs to be a square matrix. Note that your masking happens only once, in __init__, so as soon as the optimizer updates the weights, the off-diagonal entries become non-zero again. Instead, you can use the diagonal mask self.mask to zero out all non-diagonal elements in every forward pass:

    import torch
    import torch.nn as nn

    class ScalingNetwork(nn.Module):
        def __init__(self, in_features):
            super().__init__()
            # square linear layer without bias: only a scaling is learned
            self.linear = nn.Linear(in_features, in_features, bias=False)
            # boolean mask selecting the diagonal entries
            self.mask = torch.eye(in_features, dtype=bool)

        def forward(self, x):
            # zero out the off-diagonal weights before every forward pass
            self.linear.weight.data *= self.mask
            print(self.linear.weight)
            x = self.linear(x)
            return x
    

    For instance:

    >>> m = ScalingNetwork(5)
    
    >>> m(torch.rand(1,5))
    Parameter containing:
    tensor([[-0.2987, -0.0000, -0.0000, -0.0000, -0.0000],
            [ 0.0000, -0.1042, -0.0000, -0.0000, -0.0000],
            [-0.0000,  0.0000, -0.4267,  0.0000, -0.0000],
            [ 0.0000, -0.0000, -0.0000,  0.1758,  0.0000],
            [ 0.0000,  0.0000,  0.0000, -0.0000, -0.3208]], requires_grad=True)
    tensor([[-0.1032, -0.0087, -0.1709,  0.0035, -0.1496]], grad_fn=<MmBackward0>)
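
    As a side note, since a diagonal linear layer is equivalent to an elementwise multiplication by a learnable vector, an alternative sketch (not part of the answer above; the class name DiagonalScaling is mine) is to store only the diagonal as a parameter, which avoids the masking altogether:

    import torch
    import torch.nn as nn

    class DiagonalScaling(nn.Module):
        """Learns one multiplicative scale factor per input feature."""
        def __init__(self, in_features):
            super().__init__()
            # only the diagonal is stored, so the weight can never become non-diagonal
            self.scale = nn.Parameter(torch.ones(in_features))

        def forward(self, x):
            return x * self.scale

        def get_transformation_matrix(self):
            # materialize the equivalent diagonal weight matrix if needed
            return torch.diag(self.scale)

    Because off-diagonal entries simply do not exist as parameters, no constraint can be violated during training.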