I have a tensor of shape (5 * n, n), and I want to take the first 5 rows of the first column, then move over one column and take the next 5 rows of the second column, then move over again, and so on. It's similar to torch.diagonal, but it works for non-square tensors. You can assume the tensor always has the right dimensions for this to work.
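Because the rows split evenly into n blocks of 5, one way to use torch.diagonal directly is to view the tensor as n stacked 5 x n blocks. The wanted elements are then the diagonal across the block index and the column index. This is a minimal sketch that assumes the row count is exactly 5 * n. The helper name `block_diagonal` and its `block` parameter are hypothetical.

```python
import torch

def block_diagonal(t: torch.Tensor, block: int = 5) -> torch.Tensor:
    # Hypothetical helper; assumes t has shape (block * n, n).
    rows, n = t.shape
    assert rows == block * n, "expected shape (block * n, n)"
    # View as (n, block, n): entry [b, r, c] is t[b * block + r, c].
    blocks = t.reshape(n, block, n)
    # Diagonal over dims 0 and 2 picks blocks[b, r, b]; the result has
    # shape (block, n) because the diagonal dimension is appended last.
    diag = torch.diagonal(blocks, dim1=0, dim2=2)
    # Transpose to (n, block) so flattening walks one block at a time.
    return diag.T.reshape(-1)

t = torch.arange(45).reshape(15, 3)
print(block_diagonal(t).tolist())
# [0, 3, 6, 9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44]
```

The reshape and the diagonal are views, so the only copy happens in the final `reshape(-1)`.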
For example, if my input tensor was this:
```python
>>> t = torch.arange(45).reshape(15, 3)
>>> t
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23],
        [24, 25, 26],
        [27, 28, 29],
        [30, 31, 32],
        [33, 34, 35],
        [36, 37, 38],
        [39, 40, 41],
        [42, 43, 44]])
```
then I want to get:
```python
out = tensor([ 0,  3,  6,  9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44])
```
I'd prefer to avoid a Python loop, because I will be doing this thousands of times on fairly large inputs, and I expect a loop to be very inefficient.
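A loop-free option is advanced indexing. Output position k comes from row k and column k // 5. For example, position 7 is t[7, 1] = 22. This is a minimal sketch that assumes the block height can be inferred as rows // cols, as in the question. The hypothetical helper `strided_diagonal` indexes with `...`, so it also accepts a leading batch dimension. That may help if many such tensors can be stacked and processed in one call. Whether this beats slicing at your sizes is worth measuring.

```python
import torch

def strided_diagonal(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper; expects shape (..., k * n, n).
    rows, n = t.shape[-2], t.shape[-1]
    k = rows // n  # block height, e.g. 5 for a (5 * n, n) tensor
    row_idx = torch.arange(rows, device=t.device)
    # Integer division maps rows 0..k-1 to column 0, rows k..2k-1 to column 1, etc.
    col_idx = row_idx // k
    # Paired index tensors select t[..., row_idx[i], col_idx[i]] for every i.
    return t[..., row_idx, col_idx]

t = torch.arange(45).reshape(15, 3)
print(strided_diagonal(t).tolist())
# [0, 3, 6, 9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44]

batch = torch.stack([t, t + 100])  # shape (2, 15, 3)
print(strided_diagonal(batch).shape)
# torch.Size([2, 15])
```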
I got to an efficient solution using slicing:
```python
import torch
from itertools import tee

def pairwise(iterable):
    # pairwise('ABCDEFG') --> AB BC CD DE EF FG
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

def generalized_diagonal(t):
    # Block height: 5 for a (5 * n, n) tensor
    ratio = int(max(t.shape) / min(t.shape))
    # Block corners as (column, row): (0, 0), (1, ratio), (2, 2 * ratio), ...
    indexes = ((i, i * ratio) for i in range(min(t.shape) + 1))
    # Each pair of consecutive corners bounds one (ratio, 1) slice
    parts = [t[y0:y1, x0:x1] for (x0, y0), (x1, y1) in pairwise(indexes)]
    return torch.flatten(torch.stack(parts, dim=0))

t = torch.arange(45).reshape(15, 3)
print(generalized_diagonal(t))
```
Output:
```
tensor([ 0,  3,  6,  9, 12, 16, 19, 22, 25, 28, 32, 35, 38, 41, 44])
```