
How to use custom torch.autograd.Function in nn.Sequential model


Is there any way to use a custom torch.autograd.Function inside an nn.Sequential object, or do I have to write an explicit nn.Module subclass with its own forward method? Specifically, I am trying to implement a sparse autoencoder, and I need to add the L1 norm of the code (the hidden representation) to the loss. I defined the custom torch.autograd.Function L1Penalty below and then tried to use it inside an nn.Sequential object. However, when I run it I get the error TypeError: __main__.L1Penalty is not a Module subclass. How can I solve this?

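import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class L1Penalty(torch.autograd.Function):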
    @staticmethod
    def forward(ctx, input, l1weight = 0.1):
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input, None

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables
        grad_input = input.clone().sign().mul(ctx.l1weight)
        grad_input+=grad_output
        return grad_input
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 6),
    nn.ReLU(),
    # sparsity
    L1Penalty(),  # fails: nn.Sequential only accepts nn.Module instances
    nn.Linear(6, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU()
).to(device)

Solution

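  • nn.Sequential only accepts nn.Module instances, which is why adding an autograd Function raises the TypeError. The fix is to keep the Function (with forward returning only the tensor) and wrap its apply method in a small module. First, the Function: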

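    import torch
    import torch.nn as nn

    class L1Penalty(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, l1weight=0.1):
            # identity in the forward pass; keep the input for backward
            ctx.save_for_backward(input)
            ctx.l1weight = l1weight
            return input

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors  # replaces the deprecated saved_variables
            # add the L1 subgradient l1weight * sign(input) to the upstream gradient
            grad_input = input.sign().mul(ctx.l1weight) + grad_output
            # one gradient per forward input; l1weight gets None
            return grad_input, None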
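A quick way to check that the Function behaves as intended is to push a small tensor through it: the forward pass should be the identity, and the gradient reaching the input should be the upstream gradient plus l1weight * sign(x). This is a minimal sketch assuming a recent PyTorch version (it uses ctx.saved_tensors); the input values are arbitrary.

```python
import torch


class L1Penalty(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, l1weight=0.1):
        # identity in the forward pass; keep the input for backward
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # upstream gradient plus the L1 subgradient l1weight * sign(input)
        grad_input = input.sign().mul(ctx.l1weight) + grad_output
        # trailing None for l1weight (ignored when only one input was passed)
        return grad_input, None


x = torch.tensor([-2.0, 0.0, 3.0], requires_grad=True)
y = L1Penalty.apply(x)   # uses the default l1weight=0.1
y.sum().backward()       # upstream gradient is all ones

# y has the same values as x; x.grad is 1 + 0.1 * sign(x), i.e. [0.9, 1.0, 1.1]
print(y.detach(), x.grad)
```

Note that sign(0) is 0, so entries that are exactly zero (common after a ReLU) receive no extra penalty gradient.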
    

    Then create a Lambda module that wraps any callable:

    class Lambda(nn.Module):
        """Wrap a callable (e.g. L1Penalty.apply) in an nn.Module
        so that it can be used inside nn.Sequential."""
        def __init__(self, func):
            super().__init__()
            self.func = func

        def forward(self, x):
            return self.func(x)
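The generic Lambda wrapper always calls L1Penalty.apply with the default l1weight=0.1. If you want to choose the weight per layer, a dedicated module can store it and pass it to apply positionally. The class name L1PenaltyLayer below is hypothetical, and the layer sizes are only for illustration:

```python
import torch
import torch.nn as nn


class L1Penalty(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, l1weight=0.1):
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = input.sign().mul(ctx.l1weight) + grad_output
        return grad_input, None  # no gradient for l1weight


class L1PenaltyLayer(nn.Module):
    """Hypothetical module that stores l1weight and applies L1Penalty."""

    def __init__(self, l1weight=0.1):
        super().__init__()
        self.l1weight = l1weight

    def forward(self, x):
        # apply() takes positional arguments, matching forward(ctx, input, l1weight)
        return L1Penalty.apply(x, self.l1weight)

    def extra_repr(self):
        # shown when the model is printed
        return f"l1weight={self.l1weight}"


model = nn.Sequential(
    nn.Linear(10, 6),
    nn.ReLU(),
    L1PenaltyLayer(0.05),  # sparsity with a custom weight
    nn.Linear(6, 10),
)
out = model(torch.rand(50, 10))
print(model)
```

Compared with Lambda, this keeps the penalty weight visible in the printed model and avoids wrapping the call in a lambda or functools.partial.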
    

    With the wrapper, L1Penalty.apply can go straight into nn.Sequential:

    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 6),
        nn.ReLU(),
        # sparsity
        Lambda(L1Penalty.apply),
        nn.Linear(6, 10),
        nn.ReLU(),
        nn.Linear(10, 10),
        nn.ReLU())
    
    a = torch.rand(50,10)
    b = model(a)
    print(b.shape)  # torch.Size([50, 10])
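The question also asks whether an explicit nn.Module with its own forward method would work instead. It does, and it yields the same gradients: adding l1weight times the summed absolute value of the code to the loss contributes exactly l1weight * sign(code) to the code's gradient, which is what L1Penalty injects in backward. The sketch below compares both approaches on the same weights; it assumes an MSE reconstruction loss, and the SparseAE class is hypothetical (same layer sizes as the model above).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class L1Penalty(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, l1weight=0.1):
        ctx.save_for_backward(input)
        ctx.l1weight = l1weight
        return input

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = input.sign().mul(ctx.l1weight) + grad_output
        return grad_input, None


class SparseAE(nn.Module):
    """Hypothetical explicit autoencoder: forward returns output and code."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 6), nn.ReLU())
        self.decoder = nn.Sequential(nn.Linear(6, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code


torch.manual_seed(0)
model = SparseAE()
x = torch.rand(50, 10)
l1weight = 0.1

# Approach 1: explicit L1 term on the code, summed over batch and units
out, code = model(x)
loss = F.mse_loss(out, x) + l1weight * code.abs().sum()
model.zero_grad()
loss.backward()
grads_explicit = [p.grad.clone() for p in model.parameters()]

# Approach 2: same weights, penalty injected by L1Penalty in the backward pass
code = L1Penalty.apply(model.encoder(x), l1weight)
out = model.decoder(code)
model.zero_grad()
F.mse_loss(out, x).backward()
grads_function = [p.grad.clone() for p in model.parameters()]
```

One practical difference: the explicit version computes the penalty value, so you can log it, whereas L1Penalty only alters gradients and never reports how large the penalty is.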