Tags: python, pytorch, gradient, tensor, gradient-descent

Only Tensors of floating point and complex dtype can require gradients


I am receiving the following error when I run a convolution operation inside a torch.no_grad() context:

RuntimeError: Only Tensors of floating point and complex dtype can require gradients

import torch
import torch.nn as nn

with torch.no_grad():
    ker_t = torch.tensor([[1, -1], [-1, 1]])
    in_t = torch.tensor([[14, 7, 6, 2], [4, 8, 11, 1], [3, 5, 9, 10], [12, 15, 16, 13]])
    print(in_t.shape)
    in_t = torch.unsqueeze(in_t, 0)
    in_t = torch.unsqueeze(in_t, 0)
    print(in_t.shape)

    conv = nn.Conv2d(1, 1, kernel_size=2, stride=2, dtype=torch.long)
    conv.weight[:] = ker_t
    conv(in_t)

I am sure that if I convert my input to floats this message will go away, but I want to work with integers.

But I was under the impression that running inside a with torch.no_grad() context turns off the need for gradients.
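A minimal check of that impression, assuming a PyTorch version in which nn.Conv2d accepts a dtype argument (as the code above already relies on). torch.no_grad() switches off autograd recording for operations, but factory functions and nn.Parameter still honor requires_grad=True, so integer tensors that ask for gradients are rejected even inside the context:

```python
import torch
import torch.nn as nn

# Collect what happens inside no_grad so it can be inspected afterwards
errors = []
with torch.no_grad():
    grad_enabled = torch.is_grad_enabled()  # False: operations are not recorded
    a = torch.ones(2, requires_grad=True)   # factory kwarg is still honored
    b = a * 2                                # result of an op is not tracked
    for make in (
        # an int64 leaf tensor that asks for gradients
        lambda: torch.zeros(2, dtype=torch.long, requires_grad=True),
        # the layer from the question: its weight is an nn.Parameter (requires_grad=True)
        lambda: nn.Conv2d(1, 1, kernel_size=2, stride=2, dtype=torch.long),
    ):
        try:
            make()
        except RuntimeError as e:
            errors.append(str(e))

print(grad_enabled, a.requires_grad, b.requires_grad)  # False True False
print(len(errors))  # 2: both integer cases are rejected
```

So the error is raised while the layer is being constructed, before any convolution runs.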


Solution

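  • The need for gradients comes from nn.Conv2d itself: when it is constructed, it registers its weight (and bias) as nn.Parameter tensors, which have requires_grad=True by default. torch.no_grad() only stops autograd from recording operations; it does not change that flag, so an integer dtype fails as soon as the layer is created.
    However, if you only need the forward pass, you do not need a convolution layer at all: you can call the underlying convolution function directly: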

    import torch
    import torch.nn.functional as nnf

    # Plain tensors (no nn.Parameter), so nothing asks for gradients
    ker_t = torch.tensor([[1, -1], [-1, 1]])[None, None, ...]  # shape (1, 1, 2, 2)
    in_t = torch.tensor([[14, 7, 6, 2], [4, 8, 11, 1], [3, 5, 9, 10], [12, 15, 16, 13]])[None, None, ...]  # shape (1, 1, 4, 4)
    out = nnf.conv2d(in_t, ker_t, stride=2)
    

    This gives the following output:

    tensor([[[[11, -6],
              [ 1, -4]]]])
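Whether F.conv2d has an int64 kernel can depend on the PyTorch version and backend, so some setups may still reject integer inputs. The following is a hedged sketch of an integer-only alternative: int_conv2d is a hypothetical helper (not a PyTorch API) that computes the same cross-correlation as conv2d using Tensor.unfold, a dtype-agnostic view, followed by an elementwise multiply and sum. It assumes no padding, dilation or groups, and the result is cross-checked against a float64 conv2d, which is exact for small integer values:

```python
import torch
import torch.nn.functional as nnf


def int_conv2d(x: torch.Tensor, k: torch.Tensor, stride: int = 1) -> torch.Tensor:
    """Hypothetical helper: 2-D cross-correlation (what conv2d computes) built only
    from view and elementwise ops, so it works for int64 tensors.
    No padding, dilation or groups. x: (N, C, H, W), k: (O, C, kh, kw) -> (N, O, Ho, Wo)."""
    kh, kw = k.shape[-2:]
    # Tensor.unfold(dim, size, step) returns a view and works for any dtype:
    # patches[n, c, i, j, a, b] == x[n, c, i * stride + a, j * stride + b]
    patches = x.unfold(2, kh, stride).unfold(3, kw, stride)  # (N, C, Ho, Wo, kh, kw)
    # Broadcast to (N, O, C, Ho, Wo, kh, kw), multiply, then sum over C, kh and kw
    prod = patches.unsqueeze(1) * k[None, :, :, None, None]
    return prod.sum(dim=(2, -2, -1))


ker_t = torch.tensor([[1, -1], [-1, 1]])[None, None, ...]
in_t = torch.tensor([[14, 7, 6, 2], [4, 8, 11, 1], [3, 5, 9, 10], [12, 15, 16, 13]])[None, None, ...]
out = int_conv2d(in_t, ker_t, stride=2)
print(out.tolist())  # [[[[11, -6], [1, -4]]]]
print(out.dtype)     # torch.int64

# Sanity check against float64 conv2d on random multi-channel integer data
x = torch.randint(-5, 6, (2, 3, 6, 7))
k = torch.randint(-3, 4, (4, 3, 3, 2))
for s in (1, 2):
    ref = nnf.conv2d(x.double(), k.double(), stride=s).round().long()
    assert torch.equal(int_conv2d(x, k, stride=s), ref)
```

The unfold approach materializes a broadcast product, so it trades memory for staying entirely in integer arithmetic; for large inputs, converting to float64 and back (as in the sanity check) may be the more practical route.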