Tags: python, deep-learning, pytorch, lstm, recurrent-neural-network

How to access the gradients of intermediate outputs during the training loop?


Let's say I have the following (relatively) small LSTM model.

First, let's create some pseudo input/target data:

import torch

# create pseudo input data (features)
features = torch.rand(size = (64, 24, 3)) # of shape (batch_size, num_time_steps, num_features)

# create pseudo target data
targets = torch.ones(size = (64, 24, 1)) # of shape (batch_size, num_time_steps, num_targets)

# store num. of time steps
num_time_steps = features.shape[1]

Now, let's define a simple LSTM model built on LSTMCell:

# create a simple lstm model with lstm_cell
class SmallModel(torch.nn.Module):
    
    def __init__(self):
        
        super().__init__() # initialize the parent class
    
        # define the layers
        self.lstm_cell = torch.nn.LSTMCell(input_size = features.shape[2], hidden_size = 16)
        self.fc = torch.nn.Linear(in_features = 16, out_features = targets.shape[2])
        
    def forward(self, features):
        
        # initialise states
        hx = torch.randn(64, 16)
        cx = torch.randn(64, 16)
        
        # empty lists to collect the per-step outputs
        a_s = []
        b_s = []
        c_s = []
        
        for t in range(num_time_steps): # loop through each time step
            
            # select features at the current time step t
            features_t = features[:, t, :]
            
            # forward computation at the current time step t
            hx, cx = self.lstm_cell(features_t, (hx, cx))
            out_t = torch.relu(self.fc(hx))
            
            # do some computation with the output
            a = out_t * 0.8 + 20
            b = a * 2
            c = b * 0.9
            
            a_s.append(a)
            b_s.append(b)
            c_s.append(c)
            
        a_s = torch.stack(a_s, dim = 1) # of shape (batch_size, num_time_steps, num_targets)
        b_s = torch.stack(b_s, dim = 1)
        c_s = torch.stack(c_s, dim = 1)
        
        return a_s, b_s, c_s

Instantiating the model, loss function, and optimizer:

# instantiate the model
model = SmallModel()

# loss function
loss_fn = torch.nn.MSELoss()

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

Now, during the training loop, I want to print the gradients of the intermediate outputs (a_s.grad, b_s.grad) at each epoch:

# number of epochs
n_epoch = 10

# training loop
for epoch in range(n_epoch): # loop through each epoch
    
    # zero out the gradients because PyTorch accumulates them
    optimizer.zero_grad()
    
    # make predictions
    a_s, b_s, c_s = model(features)
    
    # retain the gradients of intermediate outputs
    a_s.retain_grad()
    b_s.retain_grad()
    c_s.retain_grad()
    
    # compute loss
    loss = loss_fn(c_s, targets)
    
    # backward computation
    loss.backward()
    
    # print gradients of outputs at each epoch
    print(a_s.grad)
    print(b_s.grad)
    
    # update the weights
    optimizer.step()

But I get the following:

None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None
None

How can I get the actual gradients of the intermediate outputs?


Solution

  • The problem is that c_s is not a function of a_s and b_s.

    In your code:

    loss = func(c_s, *)
    c_s = func(a, b)   # c_s is stacked from the per-step c values, which come from a and b
    # c_s = func(a_s, b_s) does NOT hold
    

    Hence, during the backward pass, no gradients are computed for the stacked tensors a_s and b_s, even though retain_grad() was called on them.
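
    The same effect can be reproduced with a tiny standalone example (a minimal sketch, not part of the original post): retain_grad() only stores a gradient if the backward pass actually reaches that tensor, and a branch the loss does not depend on is never reached.

        import torch

        a = torch.tensor([1.0], requires_grad=True)
        b = a * 2                      # on the path to the loss
        b_branch = torch.stack([b])    # a separate branch hanging off the graph
        b_branch.retain_grad()         # retaining alone is not enough ...

        loss = (b * 3).sum()           # ... because the loss depends on b, not on b_branch
        loss.backward()

        print(b_branch.grad)   # None: backward never reaches this branch
        print(a.grad)          # tensor([6.]): gradient flows along the b path only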

    Try this modified forward function, which builds c_s directly from a_s and b_s so that c_s = func(a_s, b_s) holds:

        def forward(self, features):
            
            # initialise states
            hx = torch.randn(64, 16)
            cx = torch.randn(64, 16)
            
            # empty lists to collect the per-step outputs
            a_s = []
            b_s = []
            c_s = []
            
            for t in range(num_time_steps): # loop through each time step
                
                # select features at the current time step t
                features_t = features[:, t, :]
                
                # forward computation at the current time step t
                hx, cx = self.lstm_cell(features_t, (hx, cx))
                out_t = torch.relu(self.fc(hx))
                
                # do some computation with the output
                a = out_t * 0.8 + 20
                # b = a * 2
                # c = b * 0.9
                
                a_s.append(a)
                # b_s.append(b)
                # c_s.append(c)
                
            a_s = torch.stack(a_s, dim = 1) # of shape (batch_size, num_time_steps, num_targets)
            ##########################################
            ## c_s = func(a_s, b_s)
            ##########################################
            b_s = a_s * 2
            c_s = b_s * 0.9
            ##########################################
            ##########################################
            
            
            return a_s, b_s, c_s
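
    With this forward, the training loop above can be reused as-is: a_s and b_s now lie on the path from the loss back through c_s, so their retained gradients are populated after loss.backward(). By the chain rule (c_s = b_s * 0.9, b_s = a_s * 2), the printed values should satisfy b_s.grad = 0.9 * (gradient w.r.t. c_s) and a_s.grad = 2 * b_s.grad. A minimal usage sketch, reusing the names defined above:

        # training loop with the modified forward
        for epoch in range(n_epoch):

            optimizer.zero_grad()

            a_s, b_s, c_s = model(features)

            # a_s and b_s are now part of the graph that produces c_s,
            # so their retained gradients are filled in by backward()
            a_s.retain_grad()
            b_s.retain_grad()

            loss = loss_fn(c_s, targets)
            loss.backward()

            print(a_s.grad.shape)   # torch.Size([64, 24, 1])
            print(b_s.grad.shape)   # torch.Size([64, 24, 1])

            optimizer.step()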