Is there any efficient way to replace all zeros in a tensor with the last non-zero value in torch?
For example if I had the tensor:
tensor([[1, 0, 0, 4, 0, 5, 0, 0],
        [0, 3, 0, 6, 0, 0, 8, 0]])
The output should be:
tensor([[1, 1, 1, 4, 4, 5, 5, 5],
        [0, 3, 3, 6, 6, 6, 8, 8]])
I currently have the following code:
def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(len(output)):
        prev_value = 0
        for j in range(len(tensor[i])):
            if tensor[i,j] == 0:
                output[i,j] = prev_value
            else:
                prev_value = tensor[i,j].item()
    return output
But it feels a bit clunky, and I'm sure there has to be a better way to do this. Is it possible to write it in fewer lines or, better yet, parallelise the operation without treating the tensors as arrays?
You can remove one of the loops by vectorising over the first dimension (the rows), which leaves only a loop over the columns.
def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(1, tensor.shape[1]):
        mask = tensor[:, i] == 0
        output[mask, i] = output[mask, i-1]
    return output
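The line
    output[mask, i] = output[mask, i-1]
replaces each zero in column i with the value to its left in output. That left-hand value has already been filled in if it was originally zero, so the last non-zero value carries forward. Column 0 is never modified, so a leading zero stays zero.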
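The remaining loop over columns can probably be removed as well. The following is a minimal sketch, not part of the answer above, and it assumes a 2-D tensor: record the column index of every non-zero entry, take a running maximum of those indices along each row with torch.cummax, and then use torch.gather to read the values back. The function name fill_zeros_with_prev_nonzero is made up for illustration.

```python
import torch

def fill_zeros_with_prev_nonzero(tensor):
    # Assumes a 2-D tensor; zeros are filled along dim 1 (left to right).
    # Column index of every element, broadcast across the rows.
    cols = torch.arange(tensor.shape[1], device=tensor.device).expand_as(tensor)
    # Keep the column index where the entry is non-zero, 0 elsewhere.
    idx = torch.where(tensor != 0, cols, torch.zeros_like(cols))
    # Running maximum: index of the most recent non-zero entry so far
    # (or 0 if none has been seen yet, which keeps leading zeros as zero).
    idx = torch.cummax(idx, dim=1).values
    # Read the value at that index for every position.
    return torch.gather(tensor, 1, idx)

x = torch.tensor([[1, 0, 0, 4, 0, 5, 0, 0],
                  [0, 3, 0, 6, 0, 0, 8, 0]])
print(fill_zeros_with_prev_nonzero(x))
# tensor([[1, 1, 1, 4, 4, 5, 5, 5],
#         [0, 3, 3, 6, 6, 6, 8, 8]])
```

This version has no Python-level loop, so it should scale better on long rows or on a GPU, at the cost of allocating an index tensor of the same shape as the input.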