I am trying to implement a Self-Organizing Map where for a given input sample, the best matching unit/winning unit is chosen based on (say) L2-norm distance between the SOM and the input. To implement this, I have:
import numpy as np
import torch

# Input batch: batch-size = 512, input-dim = 84-
z = torch.randn(512, 84)
# SOM shape: (height, width, input-dim)-
som = torch.randn(40, 40, 84)
# Compute L2 distance for a single sample out of 512 samples-
dist_l2 = np.linalg.norm((som.numpy() - z[0].numpy()), ord = 2, axis = 2)
# dist_l2.shape
# (40, 40)
# Get (row, column) index of the minimum of a 2d np array-
row, col = np.unravel_index(dist_l2.argmin(), dist_l2.shape)
print(f"BMU for z[0]; row = {row}, col = {col}")
# BMU for z[0]; row = 3, col = 9
So for the first input sample of 'z', the winning unit in SOM has the index: (3, 9). I can put this in a for loop iterating over all 512 such input samples, but that is very inefficient.
Is there an efficient, vectorized way to compute this in PyTorch for the entire batch?
You can easily extend this operation to batches by expanding your som tensor:
_som = som.view(1,-1,z.size(-1)).expand(len(z),-1,-1)
# L2((512, 1600, 84), (512, 1, 84)) = (512, 1600, 1)
dist_l2 = torch.cdist(_som, z[:,None])[:,:,0]
# both are shaped (512,)
row, col = torch.unravel_index(dist_l2.argmin(1), (40,40))
Note: torch.unravel_index is available from PyTorch version 2.2 onward. If you don't have access to this version, you may resort to this user-made implementation.
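As a variation, here is a minimal sketch that should give the same result without expanding som and without needing torch.unravel_index. It assumes the SOM grid is stored row-major as (height, width, input-dim), as in the question. It passes z and the flattened SOM to torch.cdist as plain 2D tensors, then recovers (row, col) with integer division and modulo. The names bmu_flat, height, width and dim are illustrative only.

```python
import torch

torch.manual_seed(0)

z = torch.randn(512, 84)        # input batch: (batch, input_dim)
som = torch.randn(40, 40, 84)   # SOM: (height, width, input_dim)
height, width, dim = som.shape

# Flatten the grid to (height * width, input_dim); cdist then returns
# the (batch, height * width) matrix of pairwise L2 distances.
dist_l2 = torch.cdist(z, som.reshape(-1, dim))

# Flat index of the closest unit for every sample, shape (batch,).
bmu_flat = dist_l2.argmin(dim=1)

# Convert flat indices to (row, col), assuming a row-major grid.
row = torch.div(bmu_flat, width, rounding_mode="floor")
col = bmu_flat % width

print(f"BMU for z[0]; row = {row[0].item()}, col = {col[0].item()}")
```

To sanity-check it, compare row[0] and col[0] against the single-sample NumPy version from the question. The two may disagree only when two units are almost exactly equidistant from the sample, since cdist can compute distances through a matrix-multiplication path with slightly different rounding.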