
Does swapping the batch axis affect performance in PyTorch?


I know that the batch dimension is usually axis zero, and I imagine there is a reason for this: the underlying memory for each item in the batch is contiguous.

My model calls a function that becomes simpler if I have another dimension in the first axis, so that I can use x[k] instead of x[:, k].

Results from arithmetic operations seem to keep the same memory layout:

import torch

x = torch.ones(2,3,4).transpose(0,1)
y = torch.ones_like(x)
u = (x + 1)
v = (x + y)
print(x.stride(), u.stride(), v.stride())
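A minimal check of this observation, assuming a recent PyTorch version. The .contiguous() line is an addition for contrast and not part of the original question. torch.ones_like keeps x's strides because its memory_format argument defaults to torch.preserve_format.

```python
import torch

x = torch.ones(2, 3, 4).transpose(0, 1)  # shape (3, 2, 4), a non-contiguous view
y = torch.ones_like(x)                   # preserve_format: same strides as x

u = x + 1
v = x + y
w = x.contiguous()  # explicit copy into row-major (C-order) layout

# Elementwise results follow x's stride order: (4, 12, 1) for all three
print(x.stride(), u.stride(), v.stride())
# The copy gets ordinary row-major strides: (8, 4, 1) True
print(w.stride(), w.is_contiguous())
```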

When I create additional variables, I allocate them with torch.zeros and then transpose, so that the largest stride also goes to axis 1.

For example:

a,b,c = torch.zeros(
         (3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1,2)

This creates three tensors with the same batch size, x.shape[1]. In terms of memory locality, would it make any difference to have

a,b,c = torch.zeros(
  (x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1, 2, 0, *range(3, x.dim() + 2))

instead?
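A small sketch comparing the two allocations. ADDITIONAL_DIM = 5 is a hypothetical value chosen only for illustration, and x is the same transposed tensor as in the first snippet. permute does not accept `...`, so the trailing axes are listed explicitly (this is equivalent to .movedim(0, 2)).

```python
import torch

x = torch.ones(2, 3, 4).transpose(0, 1)  # shape (3, 2, 4); batch size is x.shape[1]
ADDITIONAL_DIM = 5                       # hypothetical, for illustration only
tail = tuple(x.shape[2:])

# Variant 1: zeros, then transpose(1, 2)
a1, b1, c1 = torch.zeros(
    (3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + tail
).transpose(1, 2)

# Variant 2: batch axis allocated first, then permuted to position 2
a2, b2, c2 = torch.zeros(
    (x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + tail
).permute(1, 2, 0, *range(3, x.dim() + 2))

print(a1.shape, a1.stride(), b1.storage_offset())  # (5, 2, 3, 4) (12, 60, 4, 1) 120
print(a2.shape, a2.stride(), b2.storage_offset())  # (5, 2, 3, 4) (12, 180, 4, 1) 60

# In both variants a single batch item such as a[:, 0] is one contiguous chunk
print(a1[:, 0].is_contiguous(), a2[:, 0].is_contiguous())  # True True
```

In variant 1 each of a, b and c occupies its own dense block (dense but not row-major, so a1.is_contiguous() is False). In variant 2 the three tensors are interleaved inside each batch item, so a2 on its own has gaps where b2 and c2 live. Whether this shows up in timings depends on the access pattern, so it is worth profiling rather than assuming.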

Should I care about this at all?


Solution

  • TL;DR: A slice seemingly contains less information, but it in fact shares the same storage buffer as the original tensor. Since permute only changes metadata and not the underlying buffer, slicing with x[:, k] and permuting the first two axes before indexing with [k] are essentially equivalent.


    The two are essentially the same: the underlying storage buffer stays untouched, and only the metadata changes, i.e. the shape and strides that describe how you index into that buffer.

    Let us look at a simple example:

    >>> import torch
    >>> x = torch.ones(2,3,4).transpose(0,1)
    >>> x_ptr = x.data_ptr()
    
    >>> x.shape, x.stride(), x_ptr
    (3, 2, 4), (4, 12, 1), 94674451667072
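To see where the strides (4, 12, 1) come from, here is a pure-Python sketch of the bookkeeping. The helper names are hypothetical and only mimic how a strided tensor maps an index to a position in its flat buffer; transposing swaps entries of shape and stride and never touches the buffer.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides, in elements, for a given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def transpose_meta(shape, strides, d0, d1):
    """Swap two axes in the (shape, strides) metadata; no data moves."""
    shape, strides = list(shape), list(strides)
    shape[d0], shape[d1] = shape[d1], shape[d0]
    strides[d0], strides[d1] = strides[d1], strides[d0]
    return tuple(shape), tuple(strides)


def offset(index, strides, storage_offset=0):
    """Position of element `index` in the flat storage buffer."""
    return storage_offset + sum(i * s for i, s in zip(index, strides))


base_shape = (2, 3, 4)
base_strides = contiguous_strides(base_shape)              # (12, 4, 1)
shape, strides = transpose_meta(base_shape, base_strides, 0, 1)
print(shape, strides)                                      # (3, 2, 4) (4, 12, 1)

# x[2, 1, 3] of the transposed view is element [1, 2, 3] of the base buffer
print(offset((2, 1, 3), strides), offset((1, 2, 3), base_strides))  # 23 23
```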
    

    We have kept the data pointer for our 'base' tensor in x_ptr:

    1. Slicing on the second axis:

      >>> y = x[:, 0]
      
      >>> y.shape, y.stride(), x_ptr == y.data_ptr()
      (3, 4), (4, 1), True
      

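      As you can see, x and x[:, 0] start at the same address, because the slice is a view into x's storage rather than a copy. For k > 0 the data pointer is shifted by k * x.stride(1) elements, but the storage is still shared.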

    2. Permuting the first two axes then slicing on the first one:

      >>> z = x.permute(1, 0, 2)[0]
      
      >>> z.shape, z.stride(), x_ptr == z.data_ptr()
      (3, 4), (4, 1), True
      

      Here again, x.data_ptr() is the same as z.data_ptr(), and z has exactly the same shape and strides as y: both are the same view of the same buffer.
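The data_ptr() comparison above only works for k = 0, because the slice starts at offset 0. A minimal sketch for k = 1, assuming PyTorch 2.x for Tensor.untyped_storage() (older versions expose .storage() instead): both routes yield the same shape, strides and storage offset, and a write through the view shows up in x.

```python
import torch

x = torch.ones(2, 3, 4).transpose(0, 1)  # shape (3, 2, 4), strides (4, 12, 1)
k = 1

y = x[:, k]                  # slice the second axis
z = x.permute(1, 0, 2)[k]    # move that axis first, then index it

# Same shape, strides and offset into the same buffer
print(y.shape, y.stride(), y.storage_offset())  # torch.Size([3, 4]) (4, 1) 12
print(z.shape, z.stride(), z.storage_offset())  # torch.Size([3, 4]) (4, 1) 12

# For k > 0 the data pointer is shifted, but the underlying storage is shared
same_storage = y.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
print(y.data_ptr() == x.data_ptr(), same_storage)  # False True

# Writing through z is visible in x (and y), so no copy was made
z.fill_(7)
print(torch.equal(x[:, k], torch.full((3, 4), 7.0)))  # True
```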


    In fact, you can even go from y to x's representation using torch.as_strided:

    >>> torch.as_strided(y, size=x.shape, stride=x.stride())
    tensor([[[1., 1., 1., 1.],
             [1., 1., 1., 1.]],
    
            [[1., 1., 1., 1.],
             [1., 1., 1., 1.]],
    
            [[1., 1., 1., 1.],
             [1., 1., 1., 1.]]])
    

    Same with z:

    >>> torch.as_strided(z, size=x.shape, stride=x.stride())
    

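    Both calls return a view equivalent to x, not a copy: torch.as_strided does not allocate memory, it creates a new tensor over the same storage with the given shape and strides (the storage offset defaults to that of the input, which is 0 here). These two lines just illustrate that we can still 'get back' to x from a slice of x: changing the tensor's metadata is enough to recover its apparent content.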
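A short sketch confirming that torch.as_strided returns a view rather than a copy: the result starts at x's address, and writing into it changes x. The x.clone() line is added only for contrast, as an example of an operation that really does allocate a new buffer.

```python
import torch

x = torch.ones(2, 3, 4).transpose(0, 1)
y = x[:, 0]

# Reinterpret y's storage with x's shape and strides (offset defaults to y's, i.e. 0)
r = torch.as_strided(y, size=x.shape, stride=x.stride())

print(r.shape == x.shape, r.stride() == x.stride())  # True True
print(r.data_ptr() == x.data_ptr())                  # True: no new buffer

r[2, 1, 3] = 5.0
print(x[2, 1, 3].item())                             # 5.0: the write reached x

c = x.clone()  # a real copy, for contrast
print(c.data_ptr() == x.data_ptr())                  # False
```

Note that as_strided performs no bounds or overlap checks on the strides you pass, so the regular view operations (slicing, permute, transpose) are usually the safer way to get the same effect.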