Tags: machine-learning, deep-learning, neural-network, pytorch, backpropagation

PyTorch Boolean - Stop Backpropagation?


I need to create a neural network in which I use binary gates to zero out certain tensors, which are the outputs of disabled circuits.

To improve runtime speed, I wanted to use torch.bool binary gates to stop backpropagation along the disabled circuits in the network. However, when I ran a small experiment based on the official PyTorch example for the CIFAR-10 dataset, the runtime was exactly the same for any values of gate_A and gate_B, which means the idea is not working:

import torch
import torch.nn as nn
import torch.nn.functional as F
from random import randint


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1a = nn.Conv2d(3, 6, 5)
        self.conv2a = nn.Conv2d(6, 16, 5)
        self.conv1b = nn.Conv2d(3, 6, 5)
        self.conv2b = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)
        gate_A = torch.tensor(choice, dtype=torch.bool)
        gate_B = torch.tensor(1 - choice, dtype=torch.bool)

        a = self.pool(F.relu(self.conv1a(x)))
        a = self.pool(F.relu(self.conv2a(a)))

        b = self.pool(F.relu(self.conv1b(x)))
        b = self.pool(F.relu(self.conv2b(b)))

        a *= gate_A
        b *= gate_B
        x = torch.cat([a, b], dim=1)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

How can I define gate_A and gate_B in such a way that backpropagation effectively stops if they are zero?

PS. Changing the concatenation dynamically at runtime would also change which weights are assigned to each module (for example, the weights associated with a could end up assigned to b in another pass, disrupting how the network operates).


Solution

  • Multiplying by a boolean gate only zeroes the activations; autograd still records both branches in the computation graph and runs the backward pass through them, which is why the runtime does not change. You could instead wrap the disabled branch in torch.no_grad(), so that no graph is built for it at all (the code below can probably be made more concise):

    def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)
        gate_A = torch.tensor(choice, dtype=torch.bool)
        gate_B = torch.tensor(1 - choice, dtype=torch.bool)

        if choice:
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            a *= gate_A

            with torch.no_grad():  # disable gradient computation
                b = self.pool(F.relu(self.conv1b(x)))
                b = self.pool(F.relu(self.conv2b(b)))
                b *= gate_B
        else:
            with torch.no_grad():  # disable gradient computation
                a = self.pool(F.relu(self.conv1a(x)))
                a = self.pool(F.relu(self.conv2a(a)))
                a *= gate_A

            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            b *= gate_B

        x = torch.cat([a, b], dim=1)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
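
    A quick way to confirm that the backward pass really skips the disabled branch is to inspect the .grad attributes after calling backward(). A minimal sketch, assuming the Net class from the question (with this forward) and a dummy CIFAR-10-sized batch:

    # Hedged sanity check: parameters of the branch computed under
    # torch.no_grad() receive no gradient, so their .grad stays None on a
    # fresh model, while the active branch ends up with a populated .grad.
    net = Net()
    x = torch.randn(4, 3, 32, 32)          # dummy CIFAR-10-sized batch
    net(x).sum().backward()

    print(net.conv1a.weight.grad is None)  # False if branch A was active
    print(net.conv1b.weight.grad is None)  # True if branch B was disabled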
    

    On a second look, I think the following is a simpler solution to the specific problem:

    def forward(self, x):
        # Only one gate is supposed to be enabled at random
        # However, for the experiment, I fixed the values to [1,0] and [1,1]
        choice = randint(0, 1)

        if choice:
            a = self.pool(F.relu(self.conv1a(x)))
            a = self.pool(F.relu(self.conv2a(a)))
            b = torch.zeros(shape_of_conv_output)  # replace shape of conv output here
        else:
            b = self.pool(F.relu(self.conv1b(x)))
            b = self.pool(F.relu(self.conv2b(b)))
            a = torch.zeros(shape_of_conv_output)  # replace shape of conv output here

        x = torch.cat([a, b], dim=1)

        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
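
    One detail to keep in mind with this variant: torch.zeros creates the placeholder on the CPU with the default dtype, so on a GPU run the two halves of the torch.cat would mismatch. A minimal sketch of the placeholder, assuming the CIFAR-10 model above, where each conv branch produces a (batch, 16, 5, 5) tensor (32x32 -> conv 5x5 -> 28 -> pool -> 14 -> conv 5x5 -> 10 -> pool -> 5):

    # Hedged sketch: build the zero placeholder with the same batch size,
    # device and dtype as the input so torch.cat works on CPU and GPU alike.
    b = torch.zeros(x.size(0), 16, 5, 5, device=x.device, dtype=x.dtype)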