I have three tensors: a with size (n,4,m), b with size (n,4,1), and c with size (n,k,m). a contains the batch of node features I want to get, b holds the valid node indices (invalid indices are masked with 999), and c is the batch of node feature sets to extract from. n and k are the batch size and the number of nodes, respectively.
The goal is to replace the valid node features of a with the corresponding node features of c, using the index values stored in b.
So far, I use a nested for loop to implement it:
import torch
n = 4
a = torch.arange(n*4*4).view(n,4,4)
value_c = torch.zeros(n,6,4)
b = torch.randint(0,3,(n,4,1))
b[0,1:] = 999
b[2,2:] = 999
for i in range(n):
    for j in range(4):
        if b[i,j] < 999:
            a[i,j] = value_c[i,b[i,j].long()]
But it is really slow for a large dataset. Is there any way to speed it up (e.g. with logical indexing)?
Sure. For starters, I suggest masking invalid indices with nan or inf or some other special value; using a specific integer such as 999 is just asking for hard-to-catch issues as you scale the size of your data up (999 becomes a legitimate index once there are 1000 or more nodes). This does add the complication that b will need to be of type float (so it can store nan values), and you'll have to cast it to long before using it to index. Personally I think this is worth the indexing safety, but you can do as you want.
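As a minimal sketch of that suggestion, assuming b starts out as an integer tensor with 999 as the sentinel (as in the question), the conversion to a nan mask and the cast back to long might look like this. The names b_float, valid_mask, and idx are illustrative only and do not appear in the question.

```python
import torch

# Integer index tensor in the question's layout; 999 marks an invalid slot.
b = torch.tensor([[[0], [999]], [[2], [1]]])  # shape (2, 2, 1)

# Convert to float so that nan can mark the invalid entries.
b_float = b.float()
b_float[b == 999] = float("nan")

# Later: keep only the valid entries and cast back to long for indexing.
valid_mask = ~torch.isnan(b_float)
idx = b_float[valid_mask].long()
print(idx.tolist())  # [0, 2, 1]
```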
To use list-style indexing, we'll want to unravel the indices into a, giving one 1D index tensor for each dimension of a that we index.
i = torch.arange(n).unsqueeze(1).expand(n,4).reshape(-1) # something like [0,0,0,0,1,1,1,1 ...]
j = torch.arange(4).unsqueeze(0).expand(n,4).reshape(-1) # something like [0,1,2,3,0,1,2,3,...]
k = b[i,j].squeeze() # assemble desired indices of c into 1D tensor as well
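To make the layout of these index tensors concrete, here is a small sketch with n = 2 (a value chosen only for illustration). It also shows what should be an equivalent construction with repeat_interleave and repeat, in case that reads more naturally; the names i_alt and j_alt are hypothetical.

```python
import torch

n = 2  # small batch size, chosen only for illustration

# Batch index for every (batch, slot) pair, flattened row by row.
i = torch.arange(n).unsqueeze(1).expand(n, 4).reshape(-1)
# Slot index for every (batch, slot) pair, flattened row by row.
j = torch.arange(4).unsqueeze(0).expand(n, 4).reshape(-1)

print(i.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]
print(j.tolist())  # [0, 1, 2, 3, 0, 1, 2, 3]

# Alternative construction that yields the same index tensors.
i_alt = torch.arange(n).repeat_interleave(4)
j_alt = torch.arange(4).repeat(n)
```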
Now i, j, and k each contain n×4 entries, one per slot in a, which is roughly (an upper bound on) the number of rows you want to replace. Now let's re-index each tensor once to remove all of the invalid indices.
valid = (~torch.isnan(k)).nonzero().squeeze(1) # positions of the non-nan entries
k = k[valid].long() # cast to long so it can be used as an index
i = i[valid]
j = j[valid]
And now we're ready to index.
a[i,j,:] = c[i,k,:]
You may have to do a bit of type-casting to get everything to work out (a and c, for instance, need to have the same type; in the question's example a is an integer tensor while value_c is a float tensor).
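Putting the steps together, here is a self-contained sketch that checks the vectorized assignment against the nested loop. It assumes b is a float tensor with nan marking invalid slots, as suggested above, and that a and c are both float32. The names k_nodes, expected, result, mask, rows, and result_mask are illustrative and not from the question. The last part shows a boolean-mask variant, since the question asked about logical indexing.

```python
import torch

torch.manual_seed(0)
n, k_nodes, m = 4, 6, 4  # k_nodes stands in for the question's k

a = torch.arange(n * 4 * m, dtype=torch.float32).view(n, 4, m)
c = torch.rand(n, k_nodes, m)
b = torch.randint(0, k_nodes, (n, 4, 1)).float()
b[0, 1:] = float("nan")  # mark some slots as invalid
b[2, 2:] = float("nan")

# Reference result using the nested loop.
expected = a.clone()
for bi in range(n):
    for bj in range(4):
        if not torch.isnan(b[bi, bj]):
            expected[bi, bj] = c[bi, b[bi, bj].long().item()]

# Vectorized version with flattened index tensors.
i = torch.arange(n).unsqueeze(1).expand(n, 4).reshape(-1)
j = torch.arange(4).unsqueeze(0).expand(n, 4).reshape(-1)
k = b[i, j].squeeze(1)
valid = (~torch.isnan(k)).nonzero().squeeze(1)
i, j, k = i[valid], j[valid], k[valid].long()
result = a.clone()
result[i, j, :] = c[i, k, :]

# Boolean-mask variant of the same assignment.
mask = ~torch.isnan(b.squeeze(-1))                      # shape (n, 4)
rows = torch.arange(n).unsqueeze(1).expand(n, 4)[mask]  # batch index per valid slot
result_mask = a.clone()
result_mask[mask] = c[rows, b.squeeze(-1)[mask].long()]

print(torch.equal(result, expected), torch.equal(result_mask, expected))
```

Both variants avoid the Python-level loop; which one is faster will depend on tensor sizes and device, so it is worth timing them on your own data.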