I have an input tensor that has zero padding at the start, followed by a sequence of values. Something like:
x = torch.tensor([[0, 2, 8, 12],
[0, 0, 6, 3]])
What I need is another tensor with the same shape that keeps the 0s for the padding and has an increasing sequence for the rest of the positions. So my output tensor should be:
y = ([[0, 1, 2, 3],
[0, 0, 1, 2]])
I tried something like:
MAX_SEQ=4
seq_start = np.nonzero(x)
start = seq_start[0][0]
pos_id = torch.cat((torch.from_numpy(np.zeros(start, dtype=int)).to(device), torch.arange(1, MAX_SEQ-start+1).to(device)), 0)
print(pos_id)
This works if the tensor is 1-dimensional, but it needs additional logic to handle a 2-D shape. That can be done, since np.nonzero returns a tuple and we could probably loop through those tuples while updating a counter or something similar. However, I am sure there must be a simple tensor operation that does this in 1-2 lines of code, and perhaps more efficiently.
Help appreciated
A possible solution in three small steps:
Find the index of the first non-zero element for each row. This can be done with a trick explained here (adapted here for non-binary tensors).
> idx = torch.arange(x.shape[1], 0, -1)
tensor([4, 3, 2, 1])
> xbin = torch.where(x == 0, 0, 1)
tensor([[0, 1, 1, 1],
[0, 0, 1, 1]])
> xbin*idx
tensor([[0, 3, 2, 1],
[0, 0, 2, 1]])
> indices = torch.argmax(xbin*idx, dim=1, keepdim=True)
tensor([[1],
[2]])
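The first step can be wrapped in a small helper. This is a minimal sketch, assuming a 2-D integer tensor; the name first_nonzero_index is hypothetical and not part of PyTorch. A row that contains no non-zero value at all has no meaningful answer here, so such rows would need separate handling.

```python
import torch

def first_nonzero_index(x: torch.Tensor) -> torch.Tensor:
    """Column of the first non-zero entry in each row, shape (rows, 1)."""
    cols = x.shape[1]
    # Descending weights [cols, ..., 1]: earlier columns get larger weights.
    weights = torch.arange(cols, 0, -1, device=x.device)
    # 1 where x is non-zero, 0 where it is padding.
    nonzero = (x != 0).long()
    # The first non-zero column carries the largest weighted value,
    # so argmax along each row returns its position.
    return torch.argmax(nonzero * weights, dim=1, keepdim=True)

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])
print(first_nonzero_index(x))
```

The descending weights are what make argmax pick the first non-zero column rather than an arbitrary one.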
Create an arrangement for the resulting tensor (without padding). This can be done by calling repeat and view (both tensor methods) on the result of torch.arange:
> rows, cols = x.shape
> seq = torch.arange(1, cols+1).repeat(1, rows).view(-1, cols)
tensor([[1, 2, 3, 4],
[1, 2, 3, 4]])
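As an alternative sketch for this step, the same grid can be produced with unsqueeze and expand, which return a view instead of copying the row. The values of rows and cols below are assumptions taken from the example in the question.

```python
import torch

rows, cols = 2, 4  # shape of the example tensor x

# A single row [1, ..., cols] with shape (1, cols).
base = torch.arange(1, cols + 1).unsqueeze(0)
# Viewed as (rows, cols); no data is copied.
seq = base.expand(rows, cols)
print(seq)
```

Because PyTorch broadcasts a (1, cols) tensor against a (rows, 1) tensor, even the expand call is optional if seq is only used in arithmetic with the per-row indices.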
Lastly - here's the trick! - we subtract the index of the first non-zero element from the arrangement, for each row. Entries before the first non-zero element come out as zero or negative, so we mask those padding positions and set them to zero:
> pos_id = seq - indices
tensor([[ 0, 1, 2, 3],
[-1, 0, 1, 2]])
> mask = indices > seq - 1
tensor([[ True, False, False, False],
[ True, True, False, False]])
> pos_id[mask] = 0
tensor([[0, 1, 2, 3],
[0, 0, 1, 2]])
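Putting the three steps together gives something like the sketch below. The function names position_ids and position_ids_cumsum are hypothetical. The first function follows the steps above, except that it uses clamp in place of the explicit mask, which gives the same result because every masked entry is zero or negative. The second is a different technique (two cumulative sums) that is not part of the steps above; it is included because it is close to the one-or-two-line form the question asks for.

```python
import torch

def position_ids(x: torch.Tensor) -> torch.Tensor:
    """Argmax-based variant, assuming a 2-D tensor with left zero padding."""
    cols = x.shape[1]
    weights = torch.arange(cols, 0, -1, device=x.device)
    # Step 1: column of the first non-zero element per row, shape (rows, 1).
    indices = torch.argmax((x != 0).long() * weights, dim=1, keepdim=True)
    # Step 2: [1, ..., cols] with shape (1, cols); broadcasting covers the rows.
    seq = torch.arange(1, cols + 1, device=x.device).unsqueeze(0)
    # Step 3: shift each row, then zero out the (non-positive) padding entries.
    return torch.clamp(seq - indices, min=0)

def position_ids_cumsum(x: torch.Tensor) -> torch.Tensor:
    """Cumulative-sum variant: count positions from the first non-zero onward."""
    # True at the first non-zero element and everywhere after it.
    started = (x != 0).long().cumsum(dim=1) > 0
    # Running count of "started" positions; padding stays at 0.
    return started.long().cumsum(dim=1)

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])
print(position_ids(x))
print(position_ids_cumsum(x))
```

Both variants agree on the example input. They differ on a row made entirely of zeros: the cumulative-sum variant returns all zeros for it, while the argmax variant treats the row as if the sequence started at column 0.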