
Clip or threshold a tensor using a condition and zero-pad the result in PyTorch


Let's say I have a tensor like this:

w = [[0.1, 0.7, 0.7, 0.8, 0.3],
     [0.3, 0.2, 0.9, 0.1, 0.5],
     [0.1, 0.4, 0.8, 0.3, 0.4]]

Now I want to eliminate certain values based on some condition (for example, whether they are greater than 0.5):

w = [[0.1, 0.3],
     [0.3, 0.2, 0.1, 0.5],
     [0.1, 0.4, 0.3, 0.4]]

Then pad it to equal length:

w = [[0.1, 0.3, 0, 0],
     [0.3, 0.2, 0.1, 0.5],
     [0.1, 0.4, 0.3, 0.4]]

This is how I implemented it in PyTorch:

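import torch

w = torch.rand(3, 5)
condition = w <= 0.5  # True where the value should be kept
w = [w[i][condition[i]] for i in range(w.size(0))]  # ragged list of 1-D tensors
w = torch.nn.utils.rnn.pad_sequence(w, batch_first=True)  # zero-pad to (num_rows, max_len)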
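For reference, pad_sequence returns one padded row per input sequence only when batch_first=True; with the default batch_first=False the result has shape (max_len, num_rows), i.e. it is transposed relative to the layout wanted here. A small sketch with three hand-written rows (sample data standing in for the masked rows) shows both layouts:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three ragged 1-D tensors (sample data, like the rows left after masking)
rows = [torch.tensor([0.1, 0.3]),
        torch.tensor([0.3, 0.2, 0.1, 0.5]),
        torch.tensor([0.1, 0.4, 0.3, 0.4])]

# Default batch_first=False: shape is (max_len, num_rows), i.e. transposed
time_major = pad_sequence(rows)
print(time_major.shape)  # torch.Size([4, 3])

# batch_first=True: shape is (num_rows, max_len), one zero-padded row per input
batch_major = pad_sequence(rows, batch_first=True)
print(batch_major.shape)  # torch.Size([3, 4])
print(batch_major)
```

The padding value defaults to 0.0, which is why no extra argument is needed for zero-padding.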

But apparently this is going to be extremely slow, mainly because the list comprehension loops over the rows in Python. Is there a better way to do it?


Solution

  • Here's one straightforward way: boolean masking, tensor splitting, and finally padding the split tensors with torch.nn.utils.rnn.pad_sequence(...).

    # input tensor to work with
    In [213]: w 
    Out[213]: 
    tensor([[0.1000, 0.7000, 0.7000, 0.8000, 0.3000],
            [0.3000, 0.2000, 0.9000, 0.1000, 0.5000],
            [0.1000, 0.4000, 0.8000, 0.3000, 0.4000]])
    
    # values above this should be clipped from the input tensor
    In [214]: clip_value = 0.5 
    
    # boolean mask: True for every element we keep (i.e. <= clip_value)
    In [215]: boolean_mask = (w <= clip_value) 
    
    # count the kept elements in each row (these are the split sizes)
    In [216]: summed_mask = boolean_mask.sum(dim=1) 
    
    # w[boolean_mask] is a 1-D tensor of the kept values in row-major order;
    # splitting it by the per-row counts gives one tensor per row
    In [217]: splitted_tensors = torch.split(w[boolean_mask], summed_mask.tolist())  
    
    # finally zero-pad them; batch_first=True gives shape (num_rows, max_len)
    In [219]: torch.nn.utils.rnn.pad_sequence(splitted_tensors, batch_first=True) 
    Out[219]: 
    tensor([[0.1000, 0.3000, 0.0000, 0.0000],
            [0.3000, 0.2000, 0.1000, 0.5000],
            [0.1000, 0.4000, 0.3000, 0.4000]])
    

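    A short note on efficiency: torch.split() itself is cheap, since the tensors it returns are views of its input (i.e. no copy is made). Note, however, that its input here is w[boolean_mask], which boolean indexing creates as a new tensor, and that pad_sequence allocates the padded output, so the pipeline as a whole still copies the kept values. On a GPU, summed_mask.tolist() also forces a synchronization with the host.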
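If building the tuple of split tensors is still a bottleneck, one alternative (a different technique from the answer above, sketched here under the assumption that only a left-packed, zero-padded result is needed) skips splitting entirely: a cumulative sum over the mask gives each kept element its destination column, and a single indexed assignment scatters the values into a zero tensor. The helper name mask_and_pad is hypothetical.

```python
import torch


def mask_and_pad(w: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: left-pack the elements of each row of `w` where
    `keep` is True and zero-pad the rest of the row.

    Should give the same result as masking each row and calling
    pad_sequence(..., batch_first=True), without a Python list of rows.
    """
    num_rows = w.size(0)
    counts = keep.sum(dim=1)                             # kept elements per row
    max_len = int(counts.max()) if num_rows > 0 else 0   # width of padded output
    out = w.new_zeros((num_rows, max_len))               # zero-filled result

    # Destination column of each kept element: running count of kept
    # elements in its row, minus 1 (so the first kept element goes to column 0)
    cols = keep.long().cumsum(dim=1) - 1
    # Row index of every element, broadcast to w's shape
    rows = torch.arange(num_rows, device=w.device).unsqueeze(1).expand_as(w)

    # Scatter kept values into their packed positions; the (row, col) pairs
    # are unique, so no two values are written to the same slot
    out[rows[keep], cols[keep]] = w[keep]
    return out


w = torch.tensor([[0.1, 0.7, 0.7, 0.8, 0.3],
                  [0.3, 0.2, 0.9, 0.1, 0.5],
                  [0.1, 0.4, 0.8, 0.3, 0.4]])
result = mask_and_pad(w, w <= 0.5)
print(result)
```

Like the split-based version, this still reads one value back to the host (int(counts.max())) to size the output; that is a design trade-off for getting the minimal width. Allocating the full w.size(1) columns instead would avoid it at the cost of extra padding.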