Tags: python, pytorch, autograd

Why does the jacobian of the metric tensor give zero?


I am trying to compute the derivatives of the metric tensor given as follows:

Metric tensor in spherical coordinates:

$$g = \operatorname{diag}\left(1,\ r^2,\ r^2\sin^2\theta\right)$$

As part of this, I am using PyTorch to compute the jacobian of the metric. Here is my code so far:

import torch

# initial coordinates
r0, theta0, phi0 = (3., torch.pi/2, 0.1)
coord = torch.tensor([r0, theta0, phi0], requires_grad=True)
print("r, theta, phi coordinates: ",coord.data)

def metric_tensor(coords):
    r = coords[0]
    theta = coords[1]
    phi = coords[2]
    return torch.tensor([
        [1., 0., 0.],
        [0., r ** 2, 0.],
        [0., 0., r ** 2 * torch.sin(theta) ** 2]
    ])

jacobian = torch.autograd.functional.jacobian(metric_tensor, coord, create_graph=True)

For reasons I don't understand, the jacobian always returns zero, even though the derivatives of the metric shouldn't all be zero. Could anyone point me to what the issue may be?
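
For reference, the only coordinate-dependent entries of this metric are $g_{\theta\theta} = r^2$ and $g_{\varphi\varphi} = r^2\sin^2\theta$, so the nonzero derivatives should be

$$\frac{\partial g_{\theta\theta}}{\partial r} = 2r, \qquad \frac{\partial g_{\varphi\varphi}}{\partial r} = 2r\sin^2\theta, \qquad \frac{\partial g_{\varphi\varphi}}{\partial\theta} = 2r^2\sin\theta\cos\theta .$$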


Solution

  • torch.pi does not exist in older PyTorch versions, so you may need to define it yourself. More importantly, the problem is that metric_tensor builds its output with torch.tensor([...]), which copies the current values of r and theta into a brand-new tensor and detaches it from the autograd graph, so the jacobian with respect to coords comes out as all zeros. Filling a pre-allocated tensor entry by entry keeps every element differentiable. This works for me:

    import torch
    
    # define pi manually: acos(0) * 2 == pi (torch.pi may be missing in older PyTorch versions)
    torch.pi = torch.acos(torch.zeros(1)).item() * 2

    # initial coordinates
    r0, theta0, phi0 = (3., torch.pi / 2, 0.1)
    coord = torch.tensor([r0, theta0, phi0], requires_grad=True)
    print("r, theta, phi coordinates: ", coord.data)
    
    
    
    def metric_tensor(coords):
        r = coords[0]
        theta = coords[1]
        phi = coords[2]  # unused here, kept only to mirror the input layout
        # write into a fresh tensor: indexed assignment of the r and theta
        # expressions keeps each entry attached to the autograd graph
        res = torch.zeros(3, 3)
        res[0, 0] = 1.
        res[1, 1] = r ** 2
        res[2, 2] = (r ** 2) * (torch.sin(theta) ** 2)
        return res
    
    
    jacobian = torch.autograd.functional.jacobian(metric_tensor, coord, create_graph=True)
    print(jacobian)
    

    The result is this:

    r, theta, phi coordinates:  tensor([3.0000, 1.5708, 0.1000])
    tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
             [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
             [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
    
            [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
             [ 6.0000e+00,  0.0000e+00,  0.0000e+00],
             [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
    
            [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
             [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
             [ 6.0000e+00, -7.8681e-07,  0.0000e+00]]], grad_fn=<ViewBackward>)
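
    As a quick sanity check, here is a minimal standalone snippet (assuming the metric as given) comparing these numbers with the analytic derivatives listed in the question, evaluated at r = 3, theta = pi/2:

    import math

    r, theta = 3.0, math.pi / 2

    # analytic derivatives of the diagonal metric entries, evaluated at the same point
    print(2 * r)                                           # d(g_theta_theta)/dr -> jacobian[1, 1, 0] = 6.0
    print(2 * r * math.sin(theta) ** 2)                    # d(g_phi_phi)/dr     -> jacobian[2, 2, 0] = 6.0
    print(2 * r ** 2 * math.sin(theta) * math.cos(theta))  # d(g_phi_phi)/dtheta -> jacobian[2, 2, 1] ~ 0

    The only apparent mismatch is the -7.8681e-07 entry, which is float32 round-off: theta = pi/2 is not exactly representable in single precision, so cos(theta) comes out as a tiny negative number instead of zero.

    If you prefer to avoid indexed assignment, an equivalent construction (a sketch using torch.stack and torch.diag_embed, both of which are differentiable) should give the same jacobian:

    def metric_tensor(coords):
        r, theta = coords[0], coords[1]
        diag = torch.stack([torch.ones_like(r), r ** 2, r ** 2 * torch.sin(theta) ** 2])
        return torch.diag_embed(diag)  # 3x3 diagonal matrix that stays in the autograd graph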