
How to get indices of tensors of same value in a 2-d tensor?


As described in the title, given a 2-D tensor such as:

tensor([
    [0, 1, 0, 1], # A
    [1, 1, 0, 1], # B
    [1, 0, 0, 1], # C
    [0, 1, 0, 1], # D
    [1, 1, 0, 1], # E
    [1, 1, 0, 1]  # F
])

It's easy to see that "A and D" and "B, E and F" form two groups of rows with identical values (that is, A == D and B == E == F).

So my question is: how do I get the indices of those groups?

Details:

Input: tensor above

Output: (0, 3), (1, 4, 5)  (row C has no duplicate, so it is not reported)
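To make the expected output concrete, here is a minimal plain-Python sketch (not from the original post, and not using PyTorch at all) that produces it for the example above. It groups row indices in a dictionary keyed by each row's contents as a tuple and keeps only groups with more than one row. `duplicate_row_groups` is a hypothetical helper name; for a 2-D tensor `x` you would call it on `x.tolist()`.

```python
from collections import defaultdict


def duplicate_row_groups(rows):
    """Hypothetical helper: group row indices by row content and keep
    only groups containing two or more rows.

    `rows` is any sequence of equal-length sequences, e.g. a nested list
    or the result of `tensor.tolist()` on a 2-D tensor.
    """
    groups = defaultdict(list)  # row content (as a tuple) -> row indices
    for i, row in enumerate(rows):
        groups[tuple(row)].append(i)
    # dicts preserve insertion order, so groups follow each group's first row
    return [tuple(idx) for idx in groups.values() if len(idx) > 1]


rows = [
    [0, 1, 0, 1],  # A
    [1, 1, 0, 1],  # B
    [1, 0, 0, 1],  # C
    [0, 1, 0, 1],  # D
    [1, 1, 0, 1],  # E
    [1, 1, 0, 1],  # F
]
print(duplicate_row_groups(rows))  # [(0, 3), (1, 4, 5)]
```

Because the loop runs in Python, this is best treated as a baseline for checking a vectorized version rather than as the fast path.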


Solution

  • A solution using PyTorch functions:

    import torch
    
    x = torch.tensor([
        [0, 1, 0, 1], # A
        [1, 1, 0, 1], # B
        [1, 0, 0, 1], # C
        [0, 1, 0, 1], # D
        [1, 1, 0, 1], # E
        [1, 1, 0, 1]  # F
    ])
    
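    # With dim=0, whole rows are compared. inv[k] is the index of row k's
    # unique row, and counts[j] is how many times unique row j occurs.
    _, inv, counts = torch.unique(x, dim=0, return_inverse=True, return_counts=True)
    # For each unique row that occurs more than once, collect the indices
    # of the original rows that map to it.
    print([tuple(torch.where(inv == i)[0].tolist()) for i, c in enumerate(counts) if c > 1])
    # > [(0, 3), (1, 4, 5)]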
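As an alternative sketch (a different technique from the `torch.unique` answer above, and not from the original post), rows can be compared pairwise by broadcasting: `x.unsqueeze(1) == x.unsqueeze(0)` has shape `(n, n, m)`, and `.all(dim=-1)` reduces it to an `n x n` matrix whose entry `(i, j)` is True exactly when rows `i` and `j` are identical. `duplicate_row_groups_pairwise` is a hypothetical helper name.

```python
import torch


def duplicate_row_groups_pairwise(x):
    """Hypothetical helper: return index groups of identical rows of a
    2-D tensor, dropping groups of size 1, ordered by each group's first row."""
    n = x.shape[0]
    # eq[i, j] is True when rows i and j match in every column
    eq = (x.unsqueeze(1) == x.unsqueeze(0)).all(dim=-1)
    seen = torch.zeros(n, dtype=torch.bool, device=x.device)
    groups = []
    for i in range(n):
        if seen[i]:
            continue  # row i already belongs to an earlier group
        members = torch.nonzero(eq[i]).flatten()  # all rows equal to row i
        seen[members] = True
        if members.numel() > 1:
            groups.append(tuple(members.tolist()))
    return groups


x = torch.tensor([
    [0, 1, 0, 1],  # A
    [1, 1, 0, 1],  # B
    [1, 0, 0, 1],  # C
    [0, 1, 0, 1],  # D
    [1, 1, 0, 1],  # E
    [1, 1, 0, 1],  # F
])
print(duplicate_row_groups_pairwise(x))  # [(0, 3), (1, 4, 5)]
```

This returns groups in order of each group's first row, whereas the `torch.unique` version lists them in the sorted order of the unique rows; for this example the two orders happen to coincide. The broadcast comparison materializes an `(n, n, m)` boolean tensor, so for many rows the `torch.unique` approach is the better default.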