Tags: python, pytorch, matrix-multiplication, autograd

How to multiply a 2x3x3x3 tensor by a 2x3 matrix to get a 2x3 matrix


I am trying to compute some derivatives of neural network outputs. To be precise, I need the Jacobian matrix of the function represented by the neural network, and the second derivative of that function with respect to its inputs.

For every sample, I want to multiply the derivative of the Jacobian (the second-order derivative) with a vector of the same size as the input.
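
Concretely, if H[b, j] denotes the 3x3 matrix of second derivatives of output j of sample b (the slice second_order_derivative[b][j] in the code below), then for every sample b and output j I want the quadratic form

$$\mathrm{result}_{b,j} \;=\; x_b^\top H_{b,j}\, x_b \;=\; \sum_{i,k} x_{b,i}\,\frac{\partial^2 y_{b,j}}{\partial x_{b,i}\,\partial x_{b,k}}\,x_{b,k}.$$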

I got the result that I need with this implementation:

import torch

x_1 = torch.tensor([[1.,1.,1.]], requires_grad=True)
x_2 = torch.tensor([[2.,2.,2.]], requires_grad=True)
# Input to the network with dim 2x3 --> 2 samples, 3 features
x = torch.cat((x_1,x_2),dim=0)

def calculation(x):
    c = torch.tensor([[1,2,3],[4,5,6],[7,8,9]]).float()
    return (x@c)**2

c = torch.tensor([[1,2,3],[4,5,6],[7,8,9]]).float()

# Output of the network with dimension 2x3 --> 3 outputs per sample
y = calculation(x)

# Calculation of my Jacobian with dimension 2x3x3 (one 3x3 Jacobian per sample)
g = torch.autograd.functional.jacobian(calculation,(x),create_graph=True)
jacobian_summarized = torch.sum(g, dim=0).transpose(0,1)
# Calculation of my second-order derivative with dimension 2x3x3x3 (one 3x3x3 tensor per sample)
gg = torch.autograd.functional.jacobian(lambda x: torch.sum(torch.autograd.functional.jacobian(calculation, (x), create_graph=True), dim=0).transpose(0,1),
            (x),
            create_graph=True)
second_order_derivative = torch.sum(gg, dim=0).transpose(1,2).transpose(0,1)

print('x:', x)
print('c:',c)
print('y:',y)
print('First Order Derivative:',jacobian_summarized)
print('Second Order Derivative:',second_order_derivative)

# Multiplication with for loop
result = torch.empty(0)
for i in range(y.shape[0]):
    result_row = torch.empty(0)
    for ii in range(y.shape[1]):
        result_value = (x[i].unsqueeze(0))@second_order_derivative[i][ii]@(x[i].unsqueeze(0).T)
        result_row = torch.cat((result_row, result_value), dim=1)
    result = torch.cat((result, result_row))
print(result)

I would like to know if there is a way to get the same result without the two for loops, ideally with a single tensor multiplication.


Solution

  • It seems like you're looking for einsum.

    Should be something like:

    result = torch.einsum('bi,bjik,bk->bj', x, second_order_derivative, x)
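
    Here b indexes the sample, j the output, and i and k are the two input dimensions contracted with the two copies of x, so each entry is exactly the per-sample quadratic form your nested loops build. A quick sanity check (just a sketch, assuming x, second_order_derivative and the loop's result from the question are still in scope):

    result_einsum = torch.einsum('bi,bjik,bk->bj', x, second_order_derivative, x)
    # Should print True: the einsum reproduces the 2x3 matrix from the nested loops
    print(torch.allclose(result, result_einsum))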