
Extracting blocks from block diagonal PyTorch tensor


I have a tensor of shape (m*n, m*n) and I want to extract a tensor of shape (n, m*n) containing the m blocks of size n x n that lie on the diagonal, placed side by side. For example:

>>> a
tensor([[1, 2, 0, 0],
        [3, 4, 0, 0],
        [0, 0, 5, 6],
        [0, 0, 7, 8]])

I want a function extract(a, m, n) that outputs:

>>> extract(a, 2, 2)
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
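In index form (a restatement of the example above, counting rows and columns from 0), block i covers rows and columns i*n through i*n + n - 1, and the requested output is:

```latex
\mathrm{out}_{r,\; i n + c} = a_{i n + r,\; i n + c}, \qquad 0 \le i < m, \quad 0 \le r, c < n
```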

I've thought of using slicing, because the blocks can be obtained with:

>>> for i in range(m):
...     print(a[i*n: i*n + n, i*n: i*n + n])
tensor([[1, 2],
        [3, 4]])
tensor([[5, 6],
        [7, 8]])
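The slicing loop above can be turned into the requested function by concatenating the blocks along the column dimension. This is a minimal sketch, not vectorized; the name `extract_loop` is hypothetical, and the example input is rebuilt with `torch.block_diag` (assumes PyTorch 1.8 or later).

```python
import torch

def extract_loop(a: torch.Tensor, m: int, n: int) -> torch.Tensor:
    """Hypothetical helper: slice out the m diagonal n x n blocks and join them."""
    # Block i occupies rows and columns i*n .. i*n + n - 1
    blocks = [a[i * n:(i + 1) * n, i * n:(i + 1) * n] for i in range(m)]
    # Placing the blocks side by side along dim 1 gives shape (n, m*n)
    return torch.cat(blocks, dim=1)

# The 4x4 example from the question, built from its two diagonal blocks
a = torch.block_diag(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6], [7, 8]]))
print(extract_loop(a, 2, 2))
# tensor([[1, 2, 5, 6],
#         [3, 4, 7, 8]])
```

This is easy to read and a useful reference for checking a vectorized version, but it launches one slice per block, so it gets slow for large m.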

Solution

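  • You can take advantage of reshape and advanced indexing: view the matrix as an (m, n, m, n) grid of blocks, keep the blocks whose block row equals their block column, and lay them side by side:

    import torch

    def extract(a, m, n):
        # a4[i, r, j, c] == a[i*n + r, j*n + c]: block (i, j), row r, column c
        a4 = a.reshape(m, n, m, n)
        # Select the diagonal blocks (i == j). The two index tensors are separated
        # by a slice, so the block index moves to the front: shape (m, n, n)
        idx = torch.arange(m, device=a.device)
        blocks = a4[idx, :, idx, :]
        # (m, n, n) -> (n, m, n) -> (n, m*n): blocks placed side by side
        return blocks.permute(1, 0, 2).reshape(n, m * n)


    Example:

    >>> a = torch.arange(36).reshape(6, 6)
    >>> a
    tensor([[ 0,  1,  2,  3,  4,  5],
            [ 6,  7,  8,  9, 10, 11],
            [12, 13, 14, 15, 16, 17],
            [18, 19, 20, 21, 22, 23],
            [24, 25, 26, 27, 28, 29],
            [30, 31, 32, 33, 34, 35]])

    >>> extract(a, 3, 2)
    tensor([[ 0,  1, 14, 15, 28, 29],
            [ 6,  7, 20, 21, 34, 35]])

    >>> extract(a, 2, 3)
    tensor([[ 0,  1,  2, 21, 22, 23],
            [ 6,  7,  8, 27, 28, 29],
            [12, 13, 14, 33, 34, 35]])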
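A variant of the same reshape idea can avoid advanced indexing by taking the diagonal of the 4-D block view with `Tensor.diagonal` over dims 0 and 2; this is a swapped-in technique, not part of the answer above. `diagonal` appends the diagonal dimension last, so a permute is still needed. The name `extract_diag` is hypothetical, and the snippet cross-checks it against plain slicing on a random input.

```python
import torch

def extract_diag(a: torch.Tensor, m: int, n: int) -> torch.Tensor:
    """Hypothetical alternative: diagonal of the (m, n, m, n) block view."""
    a4 = a.reshape(m, n, m, n)  # a4[i, r, j, c] == a[i*n + r, j*n + c]
    # Keep i == j over dims 0 and 2; the diagonal dim is appended last:
    # d[r, c, i] == a4[i, r, i, c], shape (n, n, m)
    d = a4.diagonal(dim1=0, dim2=2)
    # (n, n, m) -> (n, m, n) -> (n, m*n)
    return d.permute(0, 2, 1).reshape(n, m * n)

# Cross-check against explicit slicing on a random block-diagonal input
m, n = 3, 4
a = torch.block_diag(*[torch.randn(n, n) for _ in range(m)])
expected = torch.cat([a[i * n:(i + 1) * n, i * n:(i + 1) * n] for i in range(m)], dim=1)
print(torch.equal(extract_diag(a, m, n), expected))  # True
```

`diagonal` returns a view, so the only copy happens in the final `reshape` of the permuted tensor.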