
Pytorch/NumPy batched submatrix indexing


There's a single source (square) matrix L of shape (N, N):

import torch as pt
import numpy as np

N = 4
L = pt.arange(N*N).reshape(N, N)  # or np.arange(N*N).reshape(N, N)
# L is now:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])

and a matrix (vector of vectors) m of boolean masks, of shape (K, N), according to which I'd like to extract submatrices from L:

K = 3
m = pt.tensor([[ True,  True, False, False],
               [False,  True,  True, False],
               [False,  True, False,  True]])

I know how to extract a single submatrix using a single mask vector by calling L[m[i]][:, m[i]] for any i. For example, for i=0 we get

tensor([[ 0,  1],
        [ 4,  5]])

but I need to perform the operation along the entire "batch" dimension. The end result I'm looking for could be achieved with

res = []
for i in range(K):
    res.append(L[m[i]][:, m[i]])
output = pt.stack(res)

However, I hope there is a better solution that avoids the for loop. I realize that the for-loop solution itself would fail if the sum of m along the last dimension (dim/axis=1) were not constant, because pt.stack requires all tensors to have the same shape. But if I can guarantee that the sum is constant, is there a better solution? If there isn't, would changing the selector representation help? I chose boolean masks for convenience, but I would prefer better performance.
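For the NumPy side of the question, here is a minimal sketch of the single-mask extraction and the loop baseline described above, assuming `L` and `m` are the question's data rebuilt as NumPy arrays. It swaps in `np.ix_`, a NumPy function that accepts boolean masks and selects the rows and columns in one indexing step; it is not part of the question's PyTorch code.

```python
import numpy as np

# NumPy versions of the question's L and m
N, K = 4, 3
L = np.arange(N * N).reshape(N, N)
m = np.array([[True, True, False, False],
              [False, True, True, False],
              [False, True, False, True]])

# Single submatrix, two ways: rows then columns, or one np.ix_ call
# (np.ix_ treats a boolean sequence like its np.nonzero indices)
sub_two_step = L[m[0]][:, m[0]]
sub_ix = L[np.ix_(m[0], m[0])]

# Loop baseline over the batch; np.stack raises a ValueError if the
# submatrices have different shapes (non-constant row sums of m)
loop_output = np.stack([L[np.ix_(m[k], m[k])] for k in range(K)])
print(loop_output.shape)  # (3, 2, 2)
```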


Solution

  • Notice that you can get the first square by combining integer-array indexing with broadcasting:

    r = torch.tensor([0,1])
    L[r[:,None], r]
    

    output:

    tensor([[0, 1],
            [4, 5]])
    

    The same principle can be applied to the second square:

    r = torch.tensor([1,2])
    L[r[:,None], r]
    

    output:

    tensor([[ 5,  6],
            [ 9, 10]])
    

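    Stacking both index pairs into one index tensor and adding a singleton axis on each side (rows as i[:,:,None], columns as i[:,None]) extracts both squares at once: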

    i = torch.tensor([[0, 1], [1, 2]])
    L[i[:,:,None], i[:,None]]
    

    output:

    tensor([[[ 0,  1],
             [ 4,  5]],
    
            [[ 5,  6],
             [ 9, 10]]])
    

    All 3 squares:

    i = torch.tensor([
        [0, 1],
        [1, 2],
        [1, 3],
    ])
    L[i[:,:,None], i[:,None]]
    

    output:

    tensor([[[ 0,  1],
             [ 4,  5]],
    
            [[ 5,  6],
             [ 9, 10]],
    
            [[ 5,  7],
             [13, 15]]])
    

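    To summarize, I would suggest using indices instead of masks: as long as every mask selects the same number of entries, the whole batch can then be extracted with a single indexing operation.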
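A NumPy sketch of the index-based approach end to end, starting from the question's boolean masks. The helper names `masks_to_indices`, `batched_submatrices` and `batched_submatrices_masked` are hypothetical, and the code assumes every row of `m` selects the same number of entries. The broadcast indexing `L[idx[:, :, None], idx[:, None, :]]` is the expression from the answer, yielding `out[k, a, b] == L[idx[k, a], idx[k, b]]`. The last helper is a mask-only alternative shown for comparison; which variant is faster has not been measured here.

```python
import numpy as np


def masks_to_indices(m):
    """Hypothetical helper: turn a (K, N) boolean mask into a (K, s) index array.

    Assumes every row of m has the same number s of True entries.
    """
    counts = m.sum(axis=1)
    if not np.all(counts == counts[0]):
        raise ValueError("every mask row must select the same number of entries")
    # np.nonzero scans m in row-major order, so the column indices come out
    # grouped by row and sorted within each row
    _, cols = np.nonzero(m)
    return cols.reshape(m.shape[0], -1)


def batched_submatrices(L, m):
    """Hypothetical helper: stack of L[m[k]][:, m[k]] for all k, without a loop."""
    idx = masks_to_indices(m)                 # (K, s)
    # (K, s, 1) rows and (K, 1, s) columns broadcast to (K, s, s):
    # out[k, a, b] == L[idx[k, a], idx[k, b]]
    return L[idx[:, :, None], idx[:, None, :]]


def batched_submatrices_masked(L, m):
    """Hypothetical mask-only alternative using a (K, N, N) pair mask.

    Boolean indexing returns the selected elements in row-major order, so
    each consecutive run of s*s values is one submatrix read row by row.
    """
    K, N = m.shape
    s = int(m[0].sum())
    pair_mask = m[:, :, None] & m[:, None, :]   # (K, N, N)
    stacked = np.broadcast_to(L, (K, N, N))     # read-only view of L repeated K times
    return stacked[pair_mask].reshape(K, s, s)


# Usage with the question's data
N = 4
L = np.arange(N * N).reshape(N, N)
m = np.array([[True, True, False, False],
              [False, True, True, False],
              [False, True, False, True]])

idx = masks_to_indices(m)
out = batched_submatrices(L, m)
loop_output = np.stack([L[m[k]][:, m[k]] for k in range(len(m))])

print(idx.tolist())                                           # [[0, 1], [1, 2], [1, 3]]
print(np.array_equal(out, loop_output))                       # True
print(np.array_equal(out, batched_submatrices_masked(L, m)))  # True
```

In PyTorch, `m.nonzero()[:, 1].reshape(K, -1)` should produce the same index tensor, since `torch.nonzero` also returns indices in row-major (C-style) order; the broadcast indexing expression itself is unchanged.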