I want to remove the for loops from my PyTorch code.
I wrote the code below, but it's not elegant, and it doesn't make good use of the GPU either.
for i in range(idx.shape[0]):
    for j in range(idx.shape[2]):
        for k in range(idx.shape[3]):
            for x in range(idx.shape[4]):
                for y in range(idx.shape[5]):
                    default_mask[i, idx[i, 0, j, k, x, y], j, k, x, y] = 0
How can I replace these loops with a vectorized operation?
You can use Tensor.scatter_ to write zeros into default_mask at the positions given by idx along dim=1. In other words, it does the following (pseudocode):
for i, j, k, x, y in dims:
    default_mask[i, idx[i, 0, j, k, x, y], j, k, x, y] = 0
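As a small illustration (a hypothetical 2D case, not taken from the question), scattering a scalar along dim=1 with an index of shape (N, 1) sets exactly one column per row, which is the 2D analogue of the pseudocode above. The names mask and idx here are made up for the example.

```python
import torch

# Hypothetical 2D analogue: mask has shape (N, C), idx has shape (N, 1)
mask = torch.ones(3, 4)
idx = torch.tensor([[2], [0], [3]])  # int64 by default

# For each row i: mask[i, idx[i, 0]] = 0
mask.scatter_(1, idx, 0.0)

print(mask.tolist())
# [[1.0, 1.0, 0.0, 1.0], [0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0]]
```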
Since idx already has the same number of dimensions as default_mask, you just need to expand its singleton dim=1 to the size of default_mask's dim=1 (expand_as does this) before applying the scatter:
>>> default_mask.scatter_(dim=1, index=idx.expand_as(default_mask), value=0)
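The above operation is in-place: it modifies default_mask directly (note the trailing underscore). Use scatter without the underscore if you need a new tensor instead.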
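A minimal end-to-end check, assuming idx is an int64 tensor whose dim=1 has size 1 (as the idx[i, 0, ...] indexing in the question suggests) and that default_mask starts as all ones. The shapes N, C, D2..D5 are made up for illustration. It runs the original loops and the single scatter_ call and compares the results.

```python
import torch

# Hypothetical shapes; idx has a singleton dim=1
N, C, D2, D3, D4, D5 = 2, 5, 3, 4, 2, 3
torch.manual_seed(0)
idx = torch.randint(0, C, (N, 1, D2, D3, D4, D5))

# Reference: the original nested loops
loop_mask = torch.ones(N, C, D2, D3, D4, D5)
for i in range(idx.shape[0]):
    for j in range(idx.shape[2]):
        for k in range(idx.shape[3]):
            for x in range(idx.shape[4]):
                for y in range(idx.shape[5]):
                    loop_mask[i, idx[i, 0, j, k, x, y], j, k, x, y] = 0

# Vectorized: expand idx's dim=1 from 1 to C, then scatter zeros along dim=1
default_mask = torch.ones(N, C, D2, D3, D4, D5)
default_mask.scatter_(dim=1, index=idx.expand_as(default_mask), value=0)

print(torch.equal(loop_mask, default_mask))  # True
```

To run this on a GPU, create both tensors with device="cuda"; the single scatter_ call then replaces the Python-level loop. Since scatter_ only requires index.size(d) <= self.size(d) for d != dim, passing idx without expand_as should also work, though that is worth confirming on your PyTorch version.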