python numpy pytorch

Find winning unit between 2 torch tensors of different shapes


I am trying to implement a Self-Organizing Map where for a given input sample, the best matching unit/winning unit is chosen based on (say) L2-norm distance between the SOM and the input. To implement this, I have:

import numpy as np
import torch

# Input batch: batch-size = 512, input-dim = 84
z = torch.randn(512, 84)

# SOM shape: (height, width, input-dim)-
som = torch.randn(40, 40, 84)

# Compute L2 distance for a single sample out of 512 samples-
dist_l2 = np.linalg.norm((som.numpy() - z[0].numpy()), ord = 2, axis = 2)

# dist_l2.shape
# (40, 40)

# Get (row, column) index of the minimum of a 2d np array-
row, col = np.unravel_index(dist_l2.argmin(), dist_l2.shape)

print(f"BMU for z[0]; row = {row}, col  = {col}")
# BMU for z[0]; row = 3, col  = 9

So for the first input sample of 'z', the winning unit in the SOM has index (3, 9). I could put this in a for loop over all 512 input samples, but that would be very inefficient.

Is there an efficient, vectorized way to compute this in PyTorch for the entire batch?


Solution

  • You can easily extend this operation to batches by expanding your som tensor:

    _som = som.view(1,-1,z.size(-1)).expand(len(z),-1,-1)
    
    # L2((512, 1600, 84), (512, 1, 84)) = (512, 1600, 1)
    dist_l2 = torch.cdist(_som, z[:,None])[:,:,0]
    
    # both are shaped (512,)
    row, col = torch.unravel_index(dist_l2.argmin(1), (40,40))
    

    Note: torch.unravel_index is available from PyTorch version 2.2 onward. If you don't have access to that version, you may resort to this user-made implementation.
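
    For older PyTorch versions, the same result can be sketched without torch.unravel_index. The snippet below is a minimal sketch, not the implementation the note refers to: find_bmus is a hypothetical helper name, it assumes the SOM grid is flattened in row-major order, and it passes the 2-D tensors to torch.cdist directly instead of expanding som across the batch.

```python
import torch


def find_bmus(z, som):
    """Return (row, col) index tensors of the best matching unit per sample.

    z:   (batch, dim) input batch
    som: (height, width, dim) SOM weights
    find_bmus is a hypothetical helper name used only for this sketch.
    """
    height, width, dim = som.shape

    # Flatten the grid to (height * width, dim); cdist on 2-D inputs
    # returns pairwise L2 distances of shape (batch, height * width).
    dist_l2 = torch.cdist(z, som.reshape(-1, dim))

    # Flat index of the closest unit for every sample, shape (batch,)
    flat_idx = dist_l2.argmin(dim=1)

    # Manual row-major unravel: flat = row * width + col
    row = torch.div(flat_idx, width, rounding_mode="floor")
    col = flat_idx % width
    return row, col


z = torch.randn(512, 84)
som = torch.randn(40, 40, 84)

row, col = find_bmus(z, som)
print(row.shape, col.shape)
```

    Skipping the expand step means no (512, 1600, 84) view has to be built, and the floor-division/modulo pair mirrors what unravel_index does for a 2-D grid.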