
PyTorch tensor: randomly replace values that meet a condition


I have a PyTorch tensor, mask, with shape

torch.Size([8, 24, 24])

and the following unique values and counts:

> torch.unique(mask, return_counts=True)
(tensor([0, 1, 2]), tensor([2093, 1054, 1461]))

I wish to randomly replace some of the 2s with 0s, so that the number of 2s equals the number of 1s and the unique values and counts become:

> torch.unique(mask, return_counts=True)
(tensor([0, 1, 2]), tensor([2500, 1054, 1054]))
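A quick sanity check of these numbers in plain Python, using only the counts quoted above: the [8, 24, 24] shape holds 4608 elements, and equalizing the 2s with the 1s means converting exactly as many 2s as the 0 count has to grow by.

```python
# Counts reported in the question for the values 0, 1 and 2
counts = {0: 2093, 1: 1054, 2: 1461}

total = 8 * 24 * 24                    # elements in a tensor of shape [8, 24, 24]
assert sum(counts.values()) == total   # 2093 + 1054 + 1461 == 4608

num_to_change = counts[2] - counts[1]  # 2s that must become 0s so that #2 == #1
print(num_to_change)                   # 407
print(counts[0] + num_to_change)       # 2500, the target count of 0s
```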

I have tried using torch.where, without success. How can this be achieved?


Solution

  • One possible solution is to flatten the tensor with view and pick random indices with numpy.random.choice:

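    import torch
    from numpy.random import choice
    
    flat = mask.view(-1)  # flat view sharing storage with mask (mask must be contiguous)
    idx = torch.where(flat == 2)[0]  # get all indices of 2 in the flat tensor
    
    num_to_change = 2500 - 2093  # as follows from the example above (407)
    
    # draw distinct indices (no replacement); NumPy needs a CPU array
    idx_to_change = choice(idx.cpu().numpy(), size=num_to_change, replace=False)
    
    # writing through the view modifies mask in place
    flat[torch.as_tensor(idx_to_change, device=flat.device)] = 0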
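If you would rather stay entirely in PyTorch (for example, to skip the NumPy round trip when mask lives on a GPU), the same idea can be sketched with torch.randperm: shuffle the positions of the 2s and keep the first n. This is a minimal sketch under assumptions: the helper name replace_random is hypothetical, and the mask built below is a synthetic stand-in that only reproduces the shape and value counts from the question.

```python
import torch


def replace_random(mask, src, dst, n, generator=None):
    """Hypothetical helper: set n randomly chosen elements equal to src to dst, in place."""
    flat = mask.view(-1)  # requires a contiguous tensor; shares storage with mask
    idx = torch.nonzero(flat == src, as_tuple=True)[0]  # flat indices holding src
    if n > idx.numel():
        raise ValueError(f"only {idx.numel()} elements equal {src}, cannot change {n}")
    # random permutation drawn on the CPU, then moved next to the data
    perm = torch.randperm(idx.numel(), generator=generator)[:n].to(idx.device)
    flat[idx[perm]] = dst
    return mask


# Synthetic stand-in for the question's mask: same shape and value counts, random layout
torch.manual_seed(0)
values = torch.cat([
    torch.zeros(2093, dtype=torch.long),
    torch.ones(1054, dtype=torch.long),
    torch.full((1461,), 2, dtype=torch.long),
])
mask = values[torch.randperm(values.numel())].view(8, 24, 24)

counts = torch.bincount(mask.view(-1), minlength=3)
n = int(counts[2] - counts[1])  # 2s to turn into 0s so that #2 == #1
replace_random(mask, src=2, dst=0, n=n)
print(torch.unique(mask, return_counts=True))
# (tensor([0, 1, 2]), tensor([2500, 1054, 1054]))
```

Taking the first n entries of a random permutation samples without replacement, just like choice(..., replace=False). Because view(-1) only works on contiguous tensors, a non-contiguous mask would first need mask = mask.contiguous(), which returns a copy, so reassign it before calling the helper.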