I want to test my neural network.
For example, given an input tensor input, an nn.Module with some submodules module, and an output tensor output, I want to find which indices of input affected the index (1,2) of output.
More specifically, given a matmul of InputMatrix1 and InputMatrix2, the expected output is:
InputMatrix1: (0,0), (0, 1), ..., (0, 11)
InputMatrix2: (0,0), (1, 0), ..., (11, 0)
A visualization would also be fine.
Is there any method or library that can achieve this?
This is easy. You want to look at the non-zero entries of the grad of InputMatrix1 and InputMatrix2 w.r.t. the (0,0) element of the product:
import torch

x = torch.rand((12, 12), requires_grad=True)  # explicitly asking for gradient for this tensor
y = torch.rand((12, 12), requires_grad=True)  # explicitly asking for gradient for this tensor
# compute the product using the @ operator:
out = x @ y
# use back propagation to compute the gradient w.r.t. out[0, 0]:
out[0, 0].backward()
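Why the non-zero pattern of the gradient identifies the contributing entries can be seen by writing out the product for the x, y, and out above (the index names i, j, k are introduced here only for illustration):

```latex
\mathrm{out}_{0,0} = \sum_{k=0}^{11} x_{0,k}\, y_{k,0},
\qquad
\frac{\partial\, \mathrm{out}_{0,0}}{\partial x_{i,k}} = \delta_{i,0}\, y_{k,0},
\qquad
\frac{\partial\, \mathrm{out}_{0,0}}{\partial y_{k,j}} = \delta_{j,0}\, x_{0,k}
```

So only row 0 of x and column 0 of y can receive a gradient. Note that an entry that does contribute still gets a zero gradient if its partner value happens to be exactly zero, which is unlikely with torch.rand but possible with real data.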
Inspecting the non-zero elements of the inputs' gradients yields, as expected:
In []: x.grad.nonzero()
tensor([[ 0, 0], [ 0, 1], [ 0, 2], [ 0, 3], [ 0, 4], [ 0, 5], [ 0, 6], [ 0, 7], [ 0, 8], [ 0, 9], [ 0, 10], [ 0, 11]])
In []: y.grad.nonzero()
tensor([[ 0, 0], [ 1, 0], [ 2, 0], [ 3, 0], [ 4, 0], [ 5, 0], [ 6, 0], [ 7, 0], [ 8, 0], [ 9, 0], [10, 0], [11, 0]])