
How to get the gradient of the output w.r.t. the input


I am having trouble getting the gradient of the output with respect to the input. The model is a simple MNIST classifier.

for num,(sample_img, sample_label) in enumerate(mnist_test):
    if num == 1:
        break

    sample_img = sample_img.to(device)
    sample_img.requires_grad = True
    prediction = model(sample_img.unsqueeze(dim=0))
    cost = criterion(prediction, torch.tensor([sample_label]).to(device))
    optimizer.zero_grad()
    cost.backward()
    print(sample_label)
    print(sample_img.shape)

    plt.imshow(sample_img.detach().cpu().squeeze(),cmap='gray')
    plt.show()

print(sample_img.grad)

The result: sample_img.grad is None.


Solution

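  • To compute the gradient with respect to the input, the input tensor must require gradients before the forward pass. You can enable this by calling sample_img.requires_grad_() or by setting sample_img.requires_grad = True, as suggested in the comments. After cost.backward(), the gradient is available in sample_img.grad.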

    Here is a small example:

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import matplotlib.pyplot as plt
    
    
    model = nn.Sequential(  # a dummy model
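        # 10 output channels give 10 logits; with a single logit the
        # cross-entropy loss is always 0 and so is the input gradient
        nn.Conv2d(1, 10, 3),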
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Flatten()
    )
    
    sample_img = torch.rand(1, 5, 5)  # a dummy input
    sample_label = 0
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    device = "cpu"
    
    sample_img = sample_img.to(device)
    sample_img.requires_grad = True
    
    prediction = model(sample_img.unsqueeze(dim=0))
    cost = criterion(prediction, torch.tensor([sample_label]).to(device))
    optimizer.zero_grad()
    cost.backward()
    print(sample_label)
    print(sample_img.shape)
    
    plt.imshow(sample_img.detach().cpu().squeeze(), cmap='gray')
    plt.show()
    
    print(sample_img.grad.shape)
    print(sample_img.grad)
    

    Additionally, if you don't need the gradients of the model parameters, you can turn off gradient tracking for them:

    for param in model.parameters():
        param.requires_grad = False
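
As a complement to the loop above, here is a minimal sketch of two ways to read the input gradient once the parameters are frozen. It assumes a hypothetical stand-in model with 10 output channels (not the asker's MNIST model) and a random dummy image. model.requires_grad_(False) freezes all parameters in one call, and torch.autograd.grad returns the gradient directly instead of storing it in sample_img.grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the sketch is reproducible

# Hypothetical stand-in model: 10 output channels give 10 logits after flattening.
model = nn.Sequential(
    nn.Conv2d(1, 10, 3),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
)
criterion = nn.CrossEntropyLoss()

# Freeze every parameter in one call (same effect as the loop over model.parameters()).
model.requires_grad_(False)

sample_img = torch.rand(1, 5, 5)  # dummy input, shape (C, H, W)
sample_img.requires_grad_()       # must happen before the forward pass
target = torch.tensor([0])        # dummy label

# Variant 1: backward() accumulates the gradient into sample_img.grad.
cost = criterion(model(sample_img.unsqueeze(0)), target)
cost.backward()
grad_from_backward = sample_img.grad.clone()

# Variant 2: torch.autograd.grad returns the gradient without touching .grad.
cost = criterion(model(sample_img.unsqueeze(0)), target)
(grad_direct,) = torch.autograd.grad(cost, sample_img)

print(grad_direct.shape)                                 # same shape as sample_img
print(torch.allclose(grad_from_backward, grad_direct))   # both variants should agree
print(all(p.grad is None for p in model.parameters()))   # frozen parameters get no gradient
```

The second variant can be convenient when the same input is used repeatedly, since backward() adds to sample_img.grad on every call unless it is reset in between.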