
Tensor manipulation - creating a positional tensor from a given tensor


I have an input tensor that has zero padding at the start, followed by a sequence of values. Something like:

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])

What I need is another tensor with the same shape that keeps 0s for the padding and has an increasing sequence, starting at 1, for the rest of the values. So my output tensor should be:

y = torch.tensor([[0, 1, 2, 3],
                  [0, 0, 1, 2]])

I tried something like:

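import numpy as np
import torch

device = torch.device("cpu")  # device is not defined in the snippet; CPU assumed here
x = torch.tensor([0, 2, 8, 12])  # 1-D example
MAX_SEQ = 4
seq_start = np.nonzero(x.cpu().numpy())  # tuple holding one index array for 1-D input
start = int(seq_start[0][0])  # index of the first non-zero element
pos_id = torch.cat((torch.from_numpy(np.zeros(start, dtype=np.int64)).to(device), torch.arange(1, MAX_SEQ - start + 1).to(device)), 0)
print(pos_id)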

This works if the tensor is 1-D, but it needs additional logic to handle a 2-D tensor. Since np.nonzero returns a tuple of index arrays, I could probably loop through them and update a counter or something similar. However, I am sure there must be a simple tensor operation that does this in one or two lines of code, and probably more efficiently.
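For reference, the loop-based approach described above might look roughly like the sketch below. `positional_ids_loop` is a hypothetical helper, not code from the question; it uses `torch.nonzero` per row instead of NumPy, and it assumes that a row made entirely of padding should stay all zeros.

```python
import torch


def positional_ids_loop(x: torch.Tensor) -> torch.Tensor:
    """Reference version using an explicit Python loop over rows.

    Positions before the first non-zero element stay 0; from that
    element on, positions are numbered 1, 2, 3, ...
    """
    out = torch.zeros_like(x)
    for r in range(x.shape[0]):
        nz = torch.nonzero(x[r], as_tuple=True)[0]  # indices of non-zero entries in row r
        if nz.numel() == 0:
            continue  # all-padding row: leave it as zeros (assumed behavior)
        start = int(nz[0])  # index of the first non-zero element
        n = x.shape[1] - start  # number of non-padding positions
        out[r, start:] = torch.arange(1, n + 1, dtype=x.dtype)
    return out


x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])
print(positional_ids_loop(x))
# tensor([[0, 1, 2, 3],
#         [0, 0, 1, 2]])
```

This is easy to read but runs one Python iteration per row, which is exactly the overhead a vectorized tensor operation avoids.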

Help appreciated


Solution

  • A possible solution in three small steps:

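    1. Find the index of the first non-zero element in each row. This can be done with a trick explained here (adapted here for non-binary tensors): weight the binary mask of non-zero entries by a decreasing sequence, so that the first non-zero position in each row holds the largest value, and torch.argmax then returns its index.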

      > idx = torch.arange(x.shape[1], 0, -1)
      tensor([4, 3, 2, 1])
      
      > xbin = torch.where(x == 0, 0, 1)
      tensor([[0, 1, 1, 1],
              [0, 0, 1, 1]])
      
      > xbin*idx
      tensor([[0, 3, 2, 1],
              [0, 0, 2, 1]])
      
      > indices = torch.argmax(xbin*idx, dim=1, keepdim=True)
      tensor([[1],
              [2]])
      
    2. Create the sequence 1, 2, ..., cols for each row of the result (ignoring the padding for now). This can be done by calling Tensor.repeat and Tensor.view on the output of torch.arange:

      > rows, cols = x.shape
      > seq = torch.arange(1, cols+1).repeat(1, rows).view(-1, cols)
      tensor([[1, 2, 3, 4],
              [1, 2, 3, 4]])
      
    3. Lastly (here's the trick!), we subtract the index of the first non-zero element from the sequence, row by row. Then we mask the padding positions and replace them with zeros:

      > pos_id = seq - indices
      tensor([[ 0,  1,  2,  3],
              [-1,  0,  1,  2]])
      
      > mask = indices > seq - 1
      tensor([[ True, False, False, False],
              [ True,  True, False, False]])
      
      > pos_id[mask] = 0
      tensor([[0, 1, 2, 3],
              [0, 0, 1, 2]])
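The three steps can be combined into a single function, sketched below under a few assumptions. `positional_ids` and `positional_ids_cumsum` are hypothetical names. The first function follows the steps above but replaces `repeat`/`view` with broadcasting via `expand`, uses `masked_fill` instead of in-place indexing, and additionally zeros rows that contain no non-zero value: for such rows `argmax` returns 0, and the steps above would then produce 1..cols instead of zeros. The second function is a different, swapped-in technique (a cumulative sum over the non-zero mask); it is shorter but only agrees with the first when no zeros appear after the first non-zero value in a row.

```python
import torch


def positional_ids(x: torch.Tensor) -> torch.Tensor:
    """Vectorized version of the three steps (argmax trick)."""
    rows, cols = x.shape
    xbin = (x != 0).long()                             # 1 where x is non-zero, else 0
    idx = torch.arange(cols, 0, -1, device=x.device)   # decreasing weights: cols, ..., 1
    # Step 1: argmax of the weighted mask gives the first non-zero column per row.
    indices = torch.argmax(xbin * idx, dim=1, keepdim=True)  # shape (rows, 1)
    # Step 2: 1..cols for every row; expand broadcasts without copying.
    seq = torch.arange(1, cols + 1, device=x.device).expand(rows, cols)
    # Step 3: shift by the start index, then zero out the padding positions.
    pos_id = (seq - indices).masked_fill(indices > seq - 1, 0)
    # Edge case (assumed behavior): rows with no non-zero value become all zeros.
    pos_id = pos_id.masked_fill(xbin.sum(dim=1, keepdim=True) == 0, 0)
    return pos_id


def positional_ids_cumsum(x: torch.Tensor) -> torch.Tensor:
    """Alternative: count non-zero entries seen so far in each row.

    Only matches positional_ids when zeros occur solely as leading padding.
    """
    return (x != 0).long().cumsum(dim=1)


x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])
print(positional_ids(x))
# tensor([[0, 1, 2, 3],
#         [0, 0, 1, 2]])
```

On the example tensor both functions give the same result. They differ on a row such as [0, 2, 0, 5]: the argmax version gives [0, 1, 2, 3] (every position after the first non-zero value is counted), while the cumsum version gives [0, 1, 1, 2].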