Tags: deep-learning, pytorch, recurrent-neural-network, gradient-descent

Why does a model with one GRU layer return zero gradients?


I'm trying to compare two models in order to learn about the behaviour of the gradients.

import torch
import torch.nn as nn
import torchinfo

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()        
        self.Identity = nn.Identity()
        self.GRU      = nn.GRU(input_size=3, hidden_size=32, num_layers=2, batch_first=True)
        self.fc       = nn.Linear(32, 5)
        
    def forward(self, input_series):
                
        self.Identity(input_series)   # no-op; the output is not used
        
        output, h = self.GRU(input_series)                
        output    = output[:,  -1, :]       # get last state                        
        output    = self.fc(output) 
        output    = output.view(-1, 5, 1)   # reshape output to (batch, 5, 1)
                        
        return output
    
    
class SecondModel(nn.Module):
    def __init__(self):
        super(SecondModel, self).__init__()        
        self.GRU      = nn.GRU(input_size=3, hidden_size=32, num_layers=2, batch_first=True)        
        
    def forward(self, input_series):
                
        output, h = self.GRU(input_series)                                        
        return output

Checking the gradients of the first model gives True (zero gradients w.r.t. the earlier outputs):

model = MyModel()
x     = torch.rand([2, 10, 3])
y     = model(x)
y.retain_grad()  
y[:, -1].sum().backward()
print(torch.allclose(y.grad[:, :-1], torch.tensor(0.)))  # gradients w.r.t previous outputs are zeroes

Checking the gradients of the second model also gives True (zero gradients w.r.t. the earlier outputs):

model = SecondModel()
x     = torch.rand([2, 10, 3])
y     = model(x)
y.retain_grad()  
y[:, -1].sum().backward()
print(torch.allclose(y.grad[:, :-1], torch.tensor(0.)))  # gradients w.r.t previous outputs are zeroes

According to the answer here:

Do linear layer after GRU saved the sequence output order?

the second model (with just a GRU layer) should give non-zero gradients.

  1. What am I missing?
  2. When will we get zero or non-zero gradients?

Solution

  • The values of y.grad[:, :-1] theoretically shouldn't be zero, but here they are because y[:, :-1] doesn't refer to the same tensor objects that are used to compute y[:, -1] inside the GRU implementation. As an illustration, a simple 1-layer GRU implementation looks like this:

    import torch
    import torch.nn as nn
    
    class GRU(nn.Module):
        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.lin_r = nn.Linear(input_size + hidden_size, hidden_size)
            self.lin_z = nn.Linear(input_size + hidden_size, hidden_size)
            self.lin_in = nn.Linear(input_size, hidden_size)
            self.lin_hn = nn.Linear(hidden_size, hidden_size)
            self.hidden_size = hidden_size
    
        def forward(self, x):
            bsz, len_, in_ = x.shape
            h = torch.zeros([bsz, self.hidden_size])
            hs = []
            for i in range(len_):
                r = self.lin_r(torch.cat([x[:, i], h], dim=-1)).sigmoid()
                z = self.lin_z(torch.cat([x[:, i], h], dim=-1)).sigmoid()
                n = (self.lin_in(x[:, i]) + r * self.lin_hn(h)).tanh()
                h = (1.-z)*n + z*h
                hs.append(h)
    
            # Return the output both as a single tensor and as a list of
            # tensors actually used in computing the hidden vectors
            return torch.stack(hs, dim=1), hs
    

    Then, we have

    model = GRU(input_size=3, hidden_size=32)
    x = torch.rand([2, 10, 3])
    y, hs = model(x)
    y.retain_grad()
    for h in hs:
        h.retain_grad()
    y[:, -1].sum().backward()
    print(torch.allclose(y.grad[:, -1], torch.tensor(0.)))  # False, as expected (sanity check)
    print(torch.allclose(y.grad[:, :-1], torch.tensor(0.)))  # True, unexpected
    print(any(torch.allclose(h.grad, torch.tensor(0.)) for h in hs))  # False, as expected
    

    It appears PyTorch computes the gradients w.r.t. all tensors in hs as expected, but not those w.r.t. y: the stacked tensor y is only created after the recurrence has finished, so backpropagating from y[:, -1] reaches just the last slice of y, while the dependence on earlier time steps flows through the individual h tensors in hs rather than through y.
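
    To confirm that the earlier inputs still influence the last output even though y.grad[:, :-1] is zero, one can retain the gradient on the input instead of on the output. Here is a minimal sketch of that check using the built-in nn.GRU (this snippet is not part of the original code):

    import torch
    import torch.nn as nn

    gru = nn.GRU(input_size=3, hidden_size=32, num_layers=2, batch_first=True)
    x = torch.rand([2, 10, 3], requires_grad=True)  # track gradients w.r.t. the input
    y, h = gru(x)
    y[:, -1].sum().backward()

    # The recurrence carries every time step into the last output, so the
    # gradients w.r.t. the earlier input steps are non-zero.
    print(torch.allclose(x.grad[:, :-1], torch.tensor(0.)))  # expected: False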

    So, to answer your questions:

    1. I don't think you're missing anything. The linked answer is just not quite right, as it incorrectly assumes PyTorch would compute y.grad as expected.
    2. The theory given in a comment on the linked answer is still right, but not quite complete: a gradient is zero iff the corresponding input doesn't affect the output (see the sketch below).
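
    As a concrete illustration of that rule (a sketch of my own, not from the linked answer): a model that literally ignores the earlier time steps yields exactly zero input gradients for them, in contrast to the GRU check above.

    import torch
    import torch.nn as nn

    # Hypothetical model that only looks at the last time step of the input.
    last_only = nn.Linear(3, 5)

    x = torch.rand([2, 10, 3], requires_grad=True)
    out = last_only(x[:, -1])  # earlier time steps never enter the computation
    out.sum().backward()

    # The earlier inputs genuinely "don't matter" here, so their gradients are exactly zero.
    print(torch.allclose(x.grad[:, :-1], torch.tensor(0.)))  # expected: True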