
Pytorch logical indexing in a 3D tensor


I have three tensors: a with size (n, 4, m), b with size (n, 4, 1), and c with size (n, k, m). a holds the batch of node features I want to fill in, b holds, for each row of a, the index of the node in c to copy from (invalid indices are masked with 999), and c holds the full set of node features for each batch. n and k are the batch size and the number of nodes, respectively.

The goal is to replace each valid row of a with the corresponding row of c, using the index stored in b, i.e. a[i, j, :] = c[i, b[i, j, 0], :] for every (i, j) where b[i, j, 0] is not 999.

So far, I use a nested for loop to implement it

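import torch

n, k, m = 4, 6, 4
a = torch.arange(n * 4 * m, dtype=torch.float).view(n, 4, m)  # rows to overwrite
c = torch.zeros(n, k, m)                                     # full node feature sets
b = torch.randint(0, 3, (n, 4, 1))                           # index into c's node dimension
b[0, 1:] = 999  # mark invalid entries
b[2, 2:] = 999
for i in range(n):
    for j in range(4):
        if b[i, j, 0] < 999:
            a[i, j] = c[i, b[i, j, 0]]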

But this is really slow for a large dataset. Is there any way to speed it up (e.g. with logical indexing)?


Solution

  • Sure. For starters, I suggest masking invalid indices with nan, inf, or some other special value; using a specific integer like 999 is asking for hard-to-catch bugs as your data grows (once k exceeds 999, 999 is a perfectly valid node index). The catch is that b then has to be a float tensor (so it can store nan), and you'll have to cast it to long before using it to index. Personally I think the extra safety is worth it, but it's your call.

    To use list-style (advanced) indexing, we'll want to unravel the indices into a, giving one 1D index tensor for each dimension of a.

    i = torch.arange(n).unsqueeze(1).expand(n,4).reshape(-1)  # [0,0,0,0,1,1,1,1,...]
    j = torch.arange(4).unsqueeze(0).expand(n,4).reshape(-1)  # [0,1,2,3,0,1,2,3,...]
    
    k = b[i,j,0]  # index into c for each (i, j), as a 1D tensor of length n*4 (a tensor, not the node count k)
    

    Now i, j, and k each contain n*4 indices, one per row of a, which is an upper bound on the number of rows you want to replace. Next, let's index each tensor once more to drop all of the invalid entries.

    valid = ~torch.isnan(k)  # boolean mask of valid entries (with the 999 sentinel instead: valid = k != 999)
    k = k[valid].long()      # cast to long so k can be used as an index
    i = i[valid]
    j = j[valid]
    

    And now we're ready to index.

    a[i,j,:] = c[i,k,:]
    

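    You may have to do a bit of type-casting to get everything to work out. In particular, a and c need the same dtype, because index assignment like a[i,j,:] = c[i,k,:] raises an error on a dtype mismatch instead of casting (in the question's example, a is int64 and c is float). Something like c.to(a.dtype) fixes that.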
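
Putting the steps together, here is a minimal sketch that assumes the question's convention (invalid entries in b marked with 999 rather than nan) and the shapes a: (n, rows, m), b: (n, rows, 1), c: (n, k, m). The function names replace_rows_loop, replace_rows_indexed and replace_rows_gather are hypothetical. The indexed version follows the steps above but uses a boolean mask and .to(a.dtype). The gather version swaps in a different technique, torch.gather plus torch.where, which skips building i and j entirely; it is not one of the steps described above. If you switch b to float with nan, replace the `!= invalid` tests with ~torch.isnan(...).

```python
import torch


def replace_rows_loop(a, b, c, invalid=999):
    """Reference: the question's nested loop (hypothetical helper name)."""
    a = a.clone()
    n, rows = b.shape[0], b.shape[1]
    for i in range(n):
        for j in range(rows):
            if b[i, j, 0] != invalid:
                a[i, j] = c[i, b[i, j, 0]]
    return a


def replace_rows_indexed(a, b, c, invalid=999):
    """Flatten (i, j), look up k, drop invalid entries, assign once (hypothetical helper name)."""
    a = a.clone()
    n, rows = b.shape[0], b.shape[1]
    i = torch.arange(n).repeat_interleave(rows)  # [0,0,0,0,1,1,1,1,...]
    j = torch.arange(rows).repeat(n)             # [0,1,2,3,0,1,2,3,...]
    k = b[i, j, 0]                               # index into c's node dimension
    valid = k != invalid                         # assumes the 999 sentinel, not nan
    i, j, k = i[valid], j[valid], k[valid].long()
    a[i, j, :] = c[i, k, :].to(a.dtype)          # index assignment needs matching dtypes
    return a


def replace_rows_gather(a, b, c, invalid=999):
    """Swapped-in alternative: gather a row for every slot, keep a's row where b is invalid."""
    m = a.shape[2]
    idx = b.masked_fill(b == invalid, 0).long()         # any in-range index for invalid slots
    gathered = torch.gather(c, 1, idx.expand(-1, -1, m))  # shape (n, rows, m)
    return torch.where(b != invalid, gathered.to(a.dtype), a)  # (n, rows, 1) mask broadcasts over m


# Usage: check both vectorized versions against the loop on random data.
torch.manual_seed(0)
n, k, m = 4, 6, 4
a = torch.arange(n * 4 * m, dtype=torch.float).view(n, 4, m)
c = torch.rand(n, k, m)
b = torch.randint(0, k, (n, 4, 1))
b[0, 1:] = 999
b[2, 2:] = 999
expected = replace_rows_loop(a, b, c)
assert torch.equal(replace_rows_indexed(a, b, c), expected)
assert torch.equal(replace_rows_gather(a, b, c), expected)
```

The gather version does some wasted work (it gathers a row even for invalid slots) but needs no index bookkeeping; which one is faster on your data is worth measuring rather than assuming.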