
Get the diagonal of a non-square tensor so that it ends exactly at the bottom-right corner in PyTorch


I have a tensor of shape (5 * n, n). I want to take the first 5 rows of the first column, then move over one column and take the next 5 rows of the second column, and so on. It's like torch.diagonal, but for non-square tensors, and you can assume the tensor always has the right shape for this to work.

For example, if my input tensor was this:

>>> t = torch.arange(45).reshape(15, 3)
>>> t
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23],
        [24, 25, 26],
        [27, 28, 29],
        [30, 31, 32],
        [33, 34, 35],
        [36, 37, 38],
        [39, 40, 41],
        [42, 43, 44]])

then I would want some way to get

out = tensor([0, 3, 6, 9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44])
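Stated as an index rule: with k = 5 rows per column, output position r reads row r from column r // k, so the whole result can be gathered with one advanced-indexing call. This is a minimal sketch assuming a tall (k * n, n) tensor; `blockwise_diagonal` is a hypothetical name, not a PyTorch function.

```python
import torch

def blockwise_diagonal(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: assumes t has shape (k * n, n), so each column
    # contributes k consecutive rows to the result.
    rows, cols = t.shape
    k = rows // cols
    # Output position r comes from row r, column r // k.
    r = torch.arange(rows, device=t.device)
    return t[r, r // k]

t = torch.arange(45).reshape(15, 3)
print(blockwise_diagonal(t))
# tensor([ 0,  3,  6,  9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44])
```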

I'd prefer not to use a loop, because I will be doing this thousands of times on fairly large inputs and I expect a loop would be very inefficient.
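If the many inputs share one shape and can be stacked into a single (B, k * n, n) tensor (an assumption; the question does not say so), the same index rule handles the whole batch in one call, which avoids paying Python overhead once per input. `batched_blockwise_diagonal` is a hypothetical name.

```python
import torch

def batched_blockwise_diagonal(batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: assumes batch has shape (B, k * n, n), i.e. many
    # same-sized inputs stacked along a leading batch dimension.
    _, rows, cols = batch.shape
    k = rows // cols
    r = torch.arange(rows, device=batch.device)
    # ':' keeps the batch dimension; r and r // k have the same shape,
    # so the result has shape (B, k * n).
    return batch[:, r, r // k]

batch = torch.arange(2 * 45).reshape(2, 15, 3)
out = batched_blockwise_diagonal(batch)
print(out.shape)  # torch.Size([2, 15])
```

Because the two index tensors sit next to each other after the leading slice, PyTorch places the indexed dimension where the row and column dimensions were, giving one row of results per batch item.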


Solution

  • I got to a solution using slicing, which takes one Python-level step per column rather than per element:

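    import torch
    from itertools import tee
    
    def pairwise(iterable):
        # pairwise('ABCDEFG') --> AB BC CD DE EF FG
        # (Python 3.10+ provides this as itertools.pairwise)
        a, b = tee(iterable)
        next(b, None)
        return zip(a, b)
    
    def generalized_diagonal(t):
        # Assumes t is tall, with shape (ratio * n, n).
        ratio = max(t.shape) // min(t.shape)
        # Block corners as (column, row): (0, 0), (1, ratio), (2, 2 * ratio), ...
        indexes = ((i, i * ratio) for i in range(min(t.shape) + 1))
        # Block i is rows [i * ratio, (i + 1) * ratio) of column i, shape (ratio, 1).
        parts = [t[y0:y1, x0:x1] for (x0, y0), (x1, y1) in pairwise(indexes)]
        return torch.flatten(torch.stack(parts, dim=0))
    
    t = torch.arange(45).reshape(15, 3)
    print(generalized_diagonal(t))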
    

    Output:

    tensor([ 0,  3,  6,  9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44])
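The slicing approach still builds one slice per column in a Python list comprehension. A loop-free alternative, offered as a sketch rather than a benchmarked improvement: split the rows into n blocks of k, then let `torch.diagonal` pair block i with column i. `diagonal_via_reshape` is a hypothetical name, and it assumes a (k * n, n) input as in the question.

```python
import torch

def diagonal_via_reshape(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: assumes t has shape (k * n, n).
    rows, n = t.shape
    k = rows // n
    blocks = t.reshape(n, k, n)  # blocks[i, j, c] == t[i * k + j, c]
    # Diagonal over block index i and column c: d[j, i] == t[i * k + j, i].
    # torch.diagonal drops dims 0 and 2 and appends the diagonal, so d is (k, n).
    d = torch.diagonal(blocks, dim1=0, dim2=2)
    # Transpose to (n, k) so flattening walks column by column.
    return d.t().reshape(-1)

t = torch.arange(45).reshape(15, 3)
print(diagonal_via_reshape(t))
```

The reshape is free for a contiguous input; only the final `reshape(-1)` of the transposed diagonal copies the k * n selected values.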