Replace all zeros with last non-zero value in torch


Is there any efficient way to replace all zeros in a tensor with the last non-zero value in torch?

For example, if I had the tensor:

tensor([[1, 0, 0, 4, 0, 5, 0, 0],
        [0, 3, 0, 6, 0, 0, 8, 0]])

The output should be:

tensor([[1, 1, 1, 4, 4, 5, 5, 5],
        [0, 3, 3, 6, 6, 6, 8, 8]])

I currently have the following code:

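def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(len(output)):            # loop over rows
        prev_value = 0                      # no non-zero value seen yet in this row
        for j in range(len(tensor[i])):     # loop over columns
            if tensor[i, j] == 0:
                output[i, j] = prev_value   # fill the zero with the last non-zero value
            else:
                prev_value = tensor[i, j].item()
    return output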

But it feels a bit clunky, and I'm sure there has to be a better way to do this. Is it possible to write it in fewer lines, or better yet, to parallelise the operation instead of looping over the tensor element by element as if it were a plain array?


Solution

  • You can remove one of the loops by vectorising over the first dimension (the rows): loop over the columns only and update every row in a single step.

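    def replace_zeros_with_prev_nonzero(tensor):
        output = tensor.clone()
        # Walk the columns left to right; each step handles all rows at once.
        for i in range(1, tensor.shape[1]):
            mask = tensor[:, i] == 0                # rows whose entry in column i is zero
            output[mask, i] = output[mask, i-1]     # copy the (already filled) value from the left

        return output
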

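    The line output[mask, i] = output[mask, i-1] replaces each zero in column i with the value immediately to its left in output. Because the columns are processed from left to right, that left-hand value has already been filled in if it was originally zero, so the last non-zero value propagates along the row. The only exception is column 0, which has no previous column, so leading zeros in a row stay 0.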
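If you want to drop the remaining Python loop as well, one possible fully vectorised alternative (a different technique from the answer above, offered as a sketch rather than a definitive implementation) is: record the column index of every non-zero entry, take a running maximum along each row with torch.cummax so every position knows where the most recent non-zero entry was, then read the values back with torch.gather. The function name forward_fill_zeros is hypothetical, and the sketch assumes a 2-D tensor and a PyTorch version that provides torch.cummax.

```python
import torch


def forward_fill_zeros(tensor: torch.Tensor) -> torch.Tensor:
    """Replace each zero with the last non-zero value to its left in the same row.

    Hypothetical helper; assumes a 2-D tensor. Leading zeros in a row stay 0.
    """
    n_cols = tensor.shape[1]
    # Column index of every position, shape (1, n_cols); broadcasts across rows.
    positions = torch.arange(n_cols, device=tensor.device).unsqueeze(0)
    # Keep the column index where the value is non-zero, 0 everywhere else.
    last_nonzero = torch.where(tensor != 0, positions, torch.zeros_like(positions))
    # Running maximum along each row: for every position, the index of the most
    # recent non-zero entry at or before it (0 if there is none yet).
    last_nonzero = torch.cummax(last_nonzero, dim=1).values
    # Read the values at those indices. A leading zero points at column 0,
    # which is itself 0 in that case, so it stays 0.
    return torch.gather(tensor, 1, last_nonzero)


x = torch.tensor([[1, 0, 0, 4, 0, 5, 0, 0],
                  [0, 3, 0, 6, 0, 0, 8, 0]])
print(forward_fill_zeros(x))
# tensor([[1, 1, 1, 4, 4, 5, 5, 5],
#         [0, 3, 3, 6, 6, 6, 8, 8]])
```

Every step is a single tensor operation, so the work is parallel across both rows and columns; the cost is an extra integer index tensor the same size as the input.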