
Remove for loops in PyTorch


I want to remove the nested for loops from my PyTorch code.

I wrote the code below, but it is not elegant, and it does not make good use of the GPU either.

for i in range(idx.shape[0]):
  for j in range(idx.shape[2]):
    for k in range(idx.shape[3]):
      for x in range(idx.shape[4]):
        for y in range(idx.shape[5]):
          default_mask[i,idx[i,0,j,k,x,y],j,k,x,y] = 0

And my question is:

  1. How can I remove the for loops from this code?
  2. Is there a general way to get rid of for loops like these?

Solution

  • You can use torch.Tensor.scatter_ to write zeros into default_mask at the positions given by idx along dim=1. In other words, this corresponds to the following pseudo-code:

    for i,j,k,x,y in dims:
        default_mask[i,idx[i,0,j,k,x,y],j,k,x,y] = 0
    

    Since idx already has the same number of dimensions as default_mask, you just need to expand its singleton dim=1 to the size of default_mask along that dimension before applying the scatter function.

    >>> default_mask.scatter_(dim=1, index=idx.expand_as(default_mask), value=0)
    

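    The above operation is in-place, so it modifies the tensor it is called on directly. Use the out-of-place torch.scatter if you need to keep the original tensor.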
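
To check that the scatter call really matches the nested loops from the question, here is a minimal sketch. The shapes, the random indices, and the names mask_loop and mask_scatter are made up for illustration; it assumes idx is an int64 tensor with size 1 along dim=1 and that the mask starts out as all ones.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes: batch B, C entries along dim=1, then four trailing dims.
B, C, J, K, X, Y = 2, 5, 3, 2, 2, 3
idx = torch.randint(0, C, (B, 1, J, K, X, Y))  # int64 indices into dim=1

# Reference result using the nested loops from the question.
mask_loop = torch.ones(B, C, J, K, X, Y)
for i in range(B):
    for j in range(J):
        for k in range(K):
            for x in range(X):
                for y in range(Y):
                    mask_loop[i, idx[i, 0, j, k, x, y], j, k, x, y] = 0

# Vectorized version: write 0 along dim=1 at the positions given by idx.
mask_scatter = torch.ones(B, C, J, K, X, Y)
mask_scatter.scatter_(dim=1, index=idx.expand_as(mask_scatter), value=0)

# Both approaches should produce the same mask.
same = torch.equal(mask_loop, mask_scatter)

# Each (i, j, k, x, y) position should have exactly one zero along dim=1.
zeros_per_position = (mask_scatter == 0).sum(dim=1)
print(same, zeros_per_position.unique())
```

Comparing against the loop on a small tensor like this is a cheap way to confirm the dim and index shapes before running the vectorized version on real data.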