
How do I find and replace elements at specified indices?


Let's say I now have two distance matrices up and down, both of size (batch_size,pomo_size,problem_size,problem_size). Suppose batch_size=1,pomo_size=2,problem_size=5.

up = torch.tensor([[[[8, 5, 6, 8, 2],
                     [2, 9, 7, 7, 0],
                     [5, 9, 1, 7, 7],
                     [0, 5, 8, 6, 3],
                     [5, 9, 1, 0, 2]],

                    [[2, 8, 1, 4, 7],
                     [0, 0, 2, 2, 7],
                     [4, 7, 7, 9, 4],
                     [6, 6, 7, 1, 3],
                     [3, 9, 9, 7, 2]]]])
down = torch.tensor([[[[6, 1, 7, 9, 1],
                       [7, 6, 2, 7, 9],
                       [5, 3, 8, 6, 3],
                       [0, 8, 6, 3, 3],
                       [8, 9, 4, 8, 1]],

                      [[7, 2, 0, 6, 0],
                       [6, 7, 5, 3, 9],
                       [4, 8, 6, 6, 1],
                       [9, 3, 2, 2, 5],
                       [2, 5, 2, 1, 8]]]])

We also have a sequence, selected_node_list, of size (batch_size, pomo_size, selected_length); assume selected_length = 8:

selected_node_list = torch.tensor([[[0, 1, 2, 3, 4, 1, 3, 2],
                  [1, 0, 3, 4, 0, 2, 4, 3]]])

I want to take the edges defined by consecutive nodes of this sequence and, for each such edge, replace its value in down with the value from up. For example, for the first pomo of the first batch, the elements found in up are (0, 1) = 5, (1, 2) = 7, (2, 3) = 7, (3, 4) = 3, (4, 1) = 9, (1, 3) = 7, (3, 2) = 8, (2, 0) = 5. Replacing the corresponding elements of down with these values gives the following matrix for the first batch, first pomo:

[[6, 5, 7, 9, 1],
 [7, 6, 7, 7, 9],
 [5, 3, 8, 7, 3],
 [0, 8, 8, 3, 3],
 [8, 9, 4, 8, 1]]
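For a single pomo, the replacement described above can be sketched with plain advanced indexing. This is only an illustration on the question's data; the names `up0`, `down0` and `seq` are ad hoc, and the closing edge (2, 0) is treated like every other edge.

```python
import torch

# First batch, first pomo of `up` and `down` from the question
up0 = torch.tensor([[8, 5, 6, 8, 2],
                    [2, 9, 7, 7, 0],
                    [5, 9, 1, 7, 7],
                    [0, 5, 8, 6, 3],
                    [5, 9, 1, 0, 2]])
down0 = torch.tensor([[6, 1, 7, 9, 1],
                      [7, 6, 2, 7, 9],
                      [5, 3, 8, 6, 3],
                      [0, 8, 6, 3, 3],
                      [8, 9, 4, 8, 1]])
seq = torch.tensor([0, 1, 2, 3, 4, 1, 3, 2])

src = seq                                # edge start nodes
dst = torch.cat([seq[1:], seq[:1]])      # edge end nodes; the last node links back to the first
print(up0[src, dst].tolist())            # values found in up: [5, 7, 7, 3, 9, 7, 8, 5]

result = down0.clone()
result[src, dst] = up0[src, dst]         # overwrite those edges in a copy of down
print(result)
```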

In addition, the straightforward implementation with multiple nested for loops is slow. Is there a fast, batched way to do this with functions such as torch.gather or torch.scatter_?
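For reference, the nested-loop version mentioned above might look like the following sketch (the function name `replace_edges_loop` is hypothetical). It is slow for large batches, but it is easy to verify and makes a handy ground truth for checking faster versions.

```python
import torch

up = torch.tensor([[[[8, 5, 6, 8, 2], [2, 9, 7, 7, 0], [5, 9, 1, 7, 7], [0, 5, 8, 6, 3], [5, 9, 1, 0, 2]],
                    [[2, 8, 1, 4, 7], [0, 0, 2, 2, 7], [4, 7, 7, 9, 4], [6, 6, 7, 1, 3], [3, 9, 9, 7, 2]]]])
down = torch.tensor([[[[6, 1, 7, 9, 1], [7, 6, 2, 7, 9], [5, 3, 8, 6, 3], [0, 8, 6, 3, 3], [8, 9, 4, 8, 1]],
                      [[7, 2, 0, 6, 0], [6, 7, 5, 3, 9], [4, 8, 6, 6, 1], [9, 3, 2, 2, 5], [2, 5, 2, 1, 8]]]])
selected_node_list = torch.tensor([[[0, 1, 2, 3, 4, 1, 3, 2],
                                    [1, 0, 3, 4, 0, 2, 4, 3]]])


def replace_edges_loop(up, down, selected_node_list):
    """Reference implementation with explicit Python loops: slow, but easy to check."""
    out = down.clone()  # leave the caller's `down` untouched
    batch_size, pomo_size, seq_len = selected_node_list.shape
    for b in range(batch_size):
        for p in range(pomo_size):
            seq = selected_node_list[b, p].tolist()
            for k in range(seq_len):
                i, j = seq[k], seq[(k + 1) % seq_len]  # (k + 1) wraps: last node -> first node
                out[b, p, i, j] = up[b, p, i, j]
    return out


out_loop = replace_edges_loop(up, down, selected_node_list)
print(out_loop)
```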

One point to note: the edge formed by the last and first nodes of the selected_node_list sequence must also be replaced.
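One way to obtain every edge, including the closing one, is to pair each node with the next node of a copy of the sequence rotated by one position with `torch.roll`. This is an assumption about how the pairs could be built, not something stated in the question.

```python
import torch

selected_node_list = torch.tensor([[[0, 1, 2, 3, 4, 1, 3, 2],
                                    [1, 0, 3, 4, 0, 2, 4, 3]]])

# The start node of each edge is the node itself; the end node is the next one.
# Rolling by -1 along the last dim moves the first node to the end, so the
# closing edge (last node -> first node) is included automatically.
src = selected_node_list
dst = torch.roll(selected_node_list, shifts=-1, dims=-1)

edges = torch.stack([src, dst], dim=-1)  # shape (batch_size, pomo_size, selected_length, 2)
print(edges[0, 1].tolist())
# [[1, 0], [0, 3], [3, 4], [4, 0], [0, 2], [2, 4], [4, 3], [3, 1]]
```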

The complete expected output for down is:

tensor([[[[6, 5, 7, 9, 1],
          [7, 6, 7, 7, 9],
          [5, 3, 8, 7, 3],
          [0, 8, 8, 3, 3],
          [8, 9, 4, 8, 1]],

         [[7, 2, 1, 4, 0],
          [0, 7, 5, 3, 9],
          [4, 8, 6, 6, 4],
          [9, 6, 2, 2, 3],
          [3, 5, 2, 7, 8]]]])

Solution

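  • Key idea: compute the flat index of every edge, then take the value from up wherever the index matches and keep the value from down everywhere else.

    You can use the torch.where function to achieve this.

    However, a condition that depends on two index dimensions (row and column) is awkward to express directly with torch.where, so we first flatten the last two dimensions, and after the calculation reshape the result back to the original shape.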
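A toy illustration of the flatten / torch.where / reshape pattern, using made-up all-zero and all-one tensors instead of the question's data:

```python
import torch

# Toy stand-ins (not the question's data): all-zero "down", all-one "up"
down = torch.zeros(1, 2, 3, 3, dtype=torch.long)
up = torch.ones(1, 2, 3, 3, dtype=torch.long)

down_flat = down.view(1, 2, -1)          # (1, 2, 9): the last two dims merged row by row
up_flat = up.view(1, 2, -1)

mask = torch.zeros(1, 2, 9, dtype=torch.bool)
mask[0, 0, 4] = True                     # flat index 4 is (row 1, col 1) when n = 3

picked = torch.where(mask, up_flat, down_flat)  # up where mask is True, down elsewhere
result = picked.view(1, 2, 3, 3)                # restore (batch, pomo, n, n)
print(result[0, 0])
```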

    Flatten tensors

    batch_size, pomo_size, problem_size = up.size()[:-1]
    select_size = selected_node_list.size()[-1] - 1  # number of edges (consecutive node pairs)
    # flat index of edge (i, j) in a row-major problem_size x problem_size matrix is i * problem_size + j
    index_flattened = selected_node_list[:, :, :-1] * problem_size + selected_node_list[:, :, 1:]
    up_flattened = up.view(batch_size, pomo_size, problem_size ** 2)  # (batch_size, pomo_size, problem_size ** 2); -1 also works as the last argument
    down_flattened = down.view(batch_size, pomo_size, -1)  # (batch_size, pomo_size, problem_size ** 2)
    print(index_flattened)
    

    Output: tensor([[[ 1, 7, 13, 19, 21, 8, 17], [ 5, 3, 19, 20, 2, 14, 23]]])

    At this point, up and down have shape (batch_size, pomo_size, problem_size ** 2), and each (i, j) pair has been converted to the flat index i * problem_size + j.
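A quick check of the row-major index arithmetic on a toy 5 x 5 matrix (the matrix `m` is made up for illustration); integer division and modulo recover (i, j) from a flat index:

```python
import torch

problem_size = 5
m = torch.arange(problem_size ** 2).view(problem_size, problem_size)  # toy 5 x 5 matrix
flat = m.view(-1)

i, j = 3, 4
k = i * problem_size + j                 # row-major flat index of (3, 4), i.e. 19
assert flat[k] == m[i, j]

# Going back from flat indices: row = k // n, column = k % n
ks = torch.tensor([1, 7, 13, 19])        # the first four indices of the first pomo
rows = torch.div(ks, problem_size, rounding_mode="floor")
cols = ks % problem_size
print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 1), (1, 2), (2, 3), (3, 4)]
```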

    Create mask tensors

    We need a tensor with the same shape, (batch_size, pomo_size, problem_size ** 2), that is 1 where the value should be taken from up and 0 otherwise.

    It is easy to build this mask with a for loop, but it can also be done without one in a slightly more involved way: create a one-hot vector for each index and sum them.
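The same mask can likely be built more compactly. Below is a sketch of two equivalent variants, one with `torch.nn.functional.one_hot` (one-hot rows summed over the edge dimension, as described above) and one with `Tensor.scatter_`, which writes a 1 at each index without the intermediate 4-D tensor. The index values are the ones printed earlier; note that the scatter_ variant yields 1 even for repeated indices, while the sum counts them.

```python
import torch
import torch.nn.functional as F

problem_size = 5
index_flattened = torch.tensor([[[1, 7, 13, 19, 21, 8, 17],
                                 [5, 3, 19, 20, 2, 14, 23]]])  # as printed above

# One one-hot row per edge: (batch, pomo, select_size, problem_size ** 2),
# then summed over the edge dimension
mask_flattened = F.one_hot(index_flattened, num_classes=problem_size ** 2).sum(dim=-2)

# scatter_ variant: write a 1 at every index position along the last dim
mask_scatter = torch.zeros(1, 2, problem_size ** 2, dtype=torch.long).scatter_(-1, index_flattened, 1)

print(mask_flattened.tolist())
```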

    pivots = torch.arange(0, problem_size ** 2).view(1, 1, 1, -1).expand(batch_size, pomo_size, select_size, -1)
    index = index_flattened.unsqueeze(-1).expand(-1, -1, -1, problem_size ** 2)
    mask_flattened = torch.where(pivots == index, torch.ones_like(pivots), torch.zeros_like(pivots))
    mask_flattened = torch.sum(mask_flattened, dim=-2)
    print(mask_flattened)
    

    Output: tensor([[[0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0], [0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0]]])

    Calculate result and reshape

    out = torch.where(mask_flattened > 0, up_flattened, down_flattened) # take the value from up wherever the mask is nonzero (an edge that occurs twice gives 2)
    
    out = out.view(batch_size, pomo_size, problem_size, problem_size) # reshape
    
    print(out)
    

    Output

    tensor([[[[6, 5, 7, 9, 1],
              [7, 6, 7, 7, 9],
              [5, 3, 8, 7, 3],
              [0, 8, 8, 3, 3],
              [8, 9, 4, 8, 1]],
    
             [[7, 2, 1, 4, 0],
              [0, 7, 5, 3, 9],
              [4, 8, 6, 6, 4],
              [9, 3, 2, 2, 3],
              [3, 5, 2, 7, 8]]]])
    

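    This matches the expected output except for the closing edge from the last node back to the first (for the second pomo, entry (3, 1) should be 6), which is handled in the edit below.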
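The mask built by summing one-hot vectors is a count, not a flag: if the same edge appears more than once in a sequence, the mask holds 2 or more at that position, so a test like `mask_flattened == 1` would skip that edge while `mask_flattened > 0` would not. A toy check with a hypothetical sequence 0 -> 1 -> 0 -> 1 (not from the question):

```python
import torch

problem_size = 3
# Hypothetical sequence in which edge (0, 1) is traversed twice
seq = torch.tensor([[[0, 1, 0, 1]]])
index_flattened = seq[:, :, :-1] * problem_size + seq[:, :, 1:]   # [1, 3, 1]

pivots = torch.arange(problem_size ** 2).view(1, 1, 1, -1)
mask_flattened = (pivots == index_flattened.unsqueeze(-1)).long().sum(dim=-2)
print(mask_flattened.tolist())   # [[[0, 2, 0, 1, 0, 0, 0, 0, 0]]]

print((mask_flattened == 1).nonzero()[:, -1].tolist())  # [3]: edge (0, 1) is missed
print((mask_flattened > 0).nonzero()[:, -1].tolist())   # [1, 3]: both edges are found
```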

    Edited 2023.06.18

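    I missed one more thing that needs to change.

    To also replace the edge from the last node back to the first, add the following line at the very beginning; it appends the first node of each sequence to its end:

    selected_node_list = torch.concat([selected_node_list, selected_node_list[:, :, 0].unsqueeze(-1)], dim=-1)

    With this change select_size becomes 8, and the second pomo's (3, 1) entry is also taken from up, so that row becomes [9, 6, 2, 2, 3] as expected. For the first pomo, the closing edge (2, 0) already holds 5 in both up and down, so its result does not change.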
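Since the question asks about torch.gather and torch.scatter_, here is a sketch of a variant that skips the mask entirely: gather the up values at the flat edge indices and scatter them into a copy of down. It uses `torch.roll` to close the tour, which should be equivalent to the `torch.concat` line above. The function name `replace_edges` is hypothetical, and this is an alternative to the answer's torch.where approach rather than a replacement for it. If an edge repeats, scatter receives the same up value for that index several times, so the result is still well defined.

```python
import torch


def replace_edges(up, down, selected_node_list):
    """Copy up[b, p, i, j] into down for every edge i -> j of each sequence,
    including the closing edge from the last node back to the first."""
    batch_size, pomo_size, problem_size, _ = up.shape
    src = selected_node_list
    dst = torch.roll(selected_node_list, shifts=-1, dims=-1)  # closes the tour
    index = src * problem_size + dst                          # (B, P, L) flat indices

    up_flat = up.reshape(batch_size, pomo_size, -1)
    down_flat = down.reshape(batch_size, pomo_size, -1)
    values = up_flat.gather(-1, index)                        # up values on the edges
    out_flat = down_flat.scatter(-1, index, values)           # written into a copy of down
    return out_flat.view(batch_size, pomo_size, problem_size, problem_size)


up = torch.tensor([[[[8, 5, 6, 8, 2], [2, 9, 7, 7, 0], [5, 9, 1, 7, 7], [0, 5, 8, 6, 3], [5, 9, 1, 0, 2]],
                    [[2, 8, 1, 4, 7], [0, 0, 2, 2, 7], [4, 7, 7, 9, 4], [6, 6, 7, 1, 3], [3, 9, 9, 7, 2]]]])
down = torch.tensor([[[[6, 1, 7, 9, 1], [7, 6, 2, 7, 9], [5, 3, 8, 6, 3], [0, 8, 6, 3, 3], [8, 9, 4, 8, 1]],
                      [[7, 2, 0, 6, 0], [6, 7, 5, 3, 9], [4, 8, 6, 6, 1], [9, 3, 2, 2, 5], [2, 5, 2, 1, 8]]]])
selected_node_list = torch.tensor([[[0, 1, 2, 3, 4, 1, 3, 2],
                                    [1, 0, 3, 4, 0, 2, 4, 3]]])

out = replace_edges(up, down, selected_node_list)
print(out)
```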