pytorch autograd

How to see which indices of an input affected an index of the output


I want to test my neural network.

For example, given an input tensor input, an nn.Module module with some submodules, and an output tensor output,

I want to find which indices of input affected index (1,2) of output.

More specifically, given:

  • Two input matrices of size (12, 12)
  • Operation is matmul
  • Queried index of the output matrix is: (0,0)

the expected output is:

InputMatrix1: (0,0), (0, 1), ..., (0, 11)

InputMatrix2: (0,0), (1, 0), ..., (11, 0)

A visualization would also be fine.

Is there any method or library that can achieve this?


Solution

  • This is easy. You want to look at the non-zero entries of the gradients of InputMatrix1 and InputMatrix2 w.r.t. the (0,0) element of the product:

    import torch

    x = torch.rand((12, 12), requires_grad=True)  # explicitly ask for gradients for this tensor
    y = torch.rand((12, 12), requires_grad=True)  # explicitly ask for gradients for this tensor
    # compute the product using the @ operator:
    out = x @ y
    # use backpropagation to compute the gradients w.r.t. out[0, 0]:
    out[0, 0].backward()
    

    Inspecting the non-zero elements of the inputs' gradients yields, as expected:

    In []: x.grad.nonzero()
    
    tensor([[ 0,  0],
            [ 0,  1],
            [ 0,  2],
            [ 0,  3],
            [ 0,  4],
            [ 0,  5],
            [ 0,  6],
            [ 0,  7],
            [ 0,  8],
            [ 0,  9],
            [ 0, 10],
            [ 0, 11]])
    
    In []: y.grad.nonzero()
    
    tensor([[ 0,  0],
            [ 1,  0],
            [ 2,  0],
            [ 3,  0],
            [ 4,  0],
            [ 5,  0],
            [ 6,  0],
            [ 7,  0],
            [ 8,  0],
            [ 9,  0],
            [10,  0],
            [11,  0]])
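
The same idea can be applied to an arbitrary nn.Module, as the question asks. Below is a minimal sketch, not a definitive implementation: the helper name influencing_indices and the Conv2d example are assumptions made for illustration and do not come from the original answer. The weights are set to ones so that the result is deterministic. Keep in mind that a zero gradient only shows there is no local sensitivity at this particular input; with zero weights or an inactive ReLU, an input that structurally feeds the output can still have a zero gradient.

```python
import torch
from torch import nn


def influencing_indices(module, inp, out_index):
    """Hypothetical helper: indices of `inp` with a non-zero gradient
    w.r.t. the single output element module(inp)[out_index]."""
    # Work on a fresh leaf tensor so the caller's tensor is left untouched.
    inp = inp.detach().clone().requires_grad_(True)
    out = module(inp)
    # Gradient of one scalar output element with respect to the whole input.
    (grad,) = torch.autograd.grad(out[out_index], inp)
    return grad.nonzero()


# Example: a 3x3 convolution without padding (illustrative choice).
conv = nn.Conv2d(1, 1, kernel_size=3, bias=False)
nn.init.ones_(conv.weight)  # all-ones kernel, so no gradient entry is zero by accident
image = torch.rand(1, 1, 6, 6)

# Which input pixels feed output pixel (row 1, col 2)?
idx = influencing_indices(conv, image, (0, 0, 1, 2))
print(idx)  # rows 1..3 and columns 2..4 of the input, i.e. the 3x3 receptive field
```

Using torch.autograd.grad instead of .backward() avoids accumulating into .grad, so the helper can be called repeatedly for different output indices without resetting gradients.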