python pytorch tensor

PyTorch meshgrid


I'm currently using torch 1.12 for my thesis on the Neural A* algorithm, and I'm not quite sure what meshgrid does. For example, why does

import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
torch.meshgrid(x, y)

return

tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor([[4, 5, 6], [4, 5, 6], [4, 5, 6]])

and not

tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
tensor([[4, 4, 4], [5, 5, 5], [6, 6, 6]])

Solution

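  • It all depends on the indexing mode you use, 'ij' or 'xy' (the default is 'ij'). With 'ij' (matrix indexing), the first output satisfies out[i][j] = x[i] and the second out[i][j] = y[j], which is exactly the result in the question. With 'xy' (Cartesian indexing), rows and columns are swapped: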

    >>> torch.meshgrid(x, y, indexing='xy')
    (tensor([[1, 2, 3],
             [1, 2, 3],
             [1, 2, 3]]),
     tensor([[4, 4, 4],
             [5, 5, 5],
             [6, 6, 6]]))
    

    [Figure: visualization of the two mesh grids]


    If you don't actually need a mesh grid and just want the result you expected, you can combine hstack, a transpose and repeat:

    >>> torch.hstack((x[:, None], y[:, None])).T[..., None].repeat(1, 1, 3)
    tensor([[[1, 1, 1],
             [2, 2, 2],
             [3, 3, 3]],
    
            [[4, 4, 4],
             [5, 5, 5],
             [6, 6, 6]]])
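To make the rule behind both outputs concrete, here is a small sketch (assuming PyTorch 1.10 or later, where the indexing argument exists) that checks the element-wise definition of each mode. With 'ij' the first grid varies down the rows and the second across the columns; with 'xy' it is the other way round, so for two inputs each 'xy' grid is simply the transpose of the corresponding 'ij' grid. The grid variable names are illustrative only.

```python
import torch

# Same inputs as in the question.
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# 'ij' (matrix) indexing: shape (len(x), len(y)),
# gx_ij[i, j] == x[i] and gy_ij[i, j] == y[j].
gx_ij, gy_ij = torch.meshgrid(x, y, indexing="ij")

# 'xy' (Cartesian) indexing: shape (len(y), len(x)),
# gx_xy[i, j] == x[j] and gy_xy[i, j] == y[i].
gx_xy, gy_xy = torch.meshgrid(x, y, indexing="xy")

# Check the element-wise rules explicitly.
for i in range(len(x)):
    for j in range(len(y)):
        assert gx_ij[i, j] == x[i] and gy_ij[i, j] == y[j]
for i in range(len(y)):
    for j in range(len(x)):
        assert gx_xy[i, j] == x[j] and gy_xy[i, j] == y[i]

# For two inputs, each 'xy' grid is the transpose of the 'ij' grid.
assert torch.equal(gx_xy, gx_ij.T) and torch.equal(gy_xy, gy_ij.T)

# Equivalently, the 'ij' grids are each vector broadcast along a new axis.
assert torch.equal(gx_ij, x[:, None].expand(len(x), len(y)))
assert torch.equal(gy_ij, y[None, :].expand(len(x), len(y)))
```

Passing indexing explicitly is also worthwhile because recent PyTorch versions emit a warning when it is omitted.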
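The hstack/transpose/repeat expression hard-codes the repeat count 3, which only works because both vectors have length 3. As an alternative (not part of the original answer), the sketch below stacks the two vectors into a (2, n) tensor, adds a trailing axis and broadcasts it with expand, which gives the same tensor for any length n; the names via_hstack, via_expand and n are just for illustration.

```python
import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# The expression from the answer: the repeat count 3 is hard-coded to len(x).
via_hstack = torch.hstack((x[:, None], y[:, None])).T[..., None].repeat(1, 1, 3)

# Alternative: stack -> shape (2, n), add a trailing axis -> (2, n, 1),
# then broadcast that axis to length n (-1 keeps a dimension's size).
n = x.numel()
via_expand = torch.stack((x, y))[..., None].expand(-1, -1, n)

# Iterating over dim 0 unpacks the two (n, n) tensors from the question.
a, b = via_expand
```

expand returns a view in which the repeated entries share the same memory, so in-place writes to it are unsafe; call .clone() first (or keep repeat, which always copies) if you need to modify the result.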