
torch7: Unexpected 'counts' in k-Means Clustering


I am trying to apply k-means clustering to a set of images (loaded as float torch.Tensors) using the following segment of code:

print('[Clustering all samples...]')
local points = torch.Tensor(trsize, 3, 221, 221)
for i = 1,trsize do
  points[i] = trainData.data[i]:clone() -- don't want to modify the original tensors
end
points:resize(trsize, 3*221*221) -- to convert it to a 2-D tensor
local centroids, counts = unsup.kmeans(points, total_classes, 40, total_classes, nil, true)
print(counts)

When I inspect the counts tensor, some entries are larger than trsize, whereas the documentation says that counts stores the counts per centroid. I took that to mean that counts[i] is the number of samples (out of trsize) assigned to the cluster with centroid centroids[i]. Am I wrong in assuming so?

If that is indeed the case, shouldn't the sample-to-centroid assignment be a hard assignment (i.e. shouldn't the entries of counts sum to trsize, which clearly is not the case with my clustering)? Am I missing something here?

Thanks in advance.


Solution

  • In the current version of the code, the per-iteration counts are added to a running total after each iteration:

    for i = 1,niter do
      -- k-means computations...
    
      -- total counts
      totalcounts:add(counts)
    end
    

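    So in the end counts:sum() equals niter times the number of samples rather than the number of samples itself, which is why individual entries can exceed trsize.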

    As a workaround you can use the callback to obtain the final counts (non-accumulated):

    local maxiter = 40
    
    local centroids, counts = unsup.kmeans(
      points,
      total_classes,
      maxiter,
      total_classes,
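      -- reset the running totals after every iteration except the last,
      -- so that only the final iteration's counts remain
      function(i, _, totalcounts)
        if i < maxiter then totalcounts:zero() end
      end,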
      true
    )
    

    As an alternative you can use vlfeat.torch and explicitly quantize your input points after kmeans to obtain these counts:

    local assignments = kmeans:quantize(points)
    
    local counts = torch.zeros(total_classes):int()
    
    for i=1,total_classes do
      counts[i] = assignments:eq(i):sum()
    end
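
To make the accumulation effect concrete, here is a minimal plain-Lua sketch (no torch required). It mimics only the count bookkeeping described above, not the clustering itself. The helper names `toy_kmeans_counts` and `sum`, and the made-up assignments, are hypothetical and for illustration only. The sketch assumes the callback runs after the totals have been updated in each iteration.

```lua
-- Hypothetical stand-in for the count bookkeeping only (not the real unsup.kmeans).
-- assignments_per_iter[i][s] is the cluster index of sample s in iteration i.
local function toy_kmeans_counts(assignments_per_iter, k, callback)
  local totalcounts = {}
  for c = 1, k do totalcounts[c] = 0 end

  for i, assignments in ipairs(assignments_per_iter) do
    -- hard assignment: every sample lands in exactly one cluster
    local counts = {}
    for c = 1, k do counts[c] = 0 end
    for _, c in ipairs(assignments) do counts[c] = counts[c] + 1 end

    -- accumulate into the running totals, as in the loop quoted above
    for c = 1, k do totalcounts[c] = totalcounts[c] + counts[c] end

    -- assumption: the callback sees the updated totals and may modify them in place
    if callback then callback(i, nil, totalcounts) end
  end
  return totalcounts
end

local function sum(t)
  local s = 0
  for _, v in ipairs(t) do s = s + v end
  return s
end

-- 4 samples, 2 clusters, 3 iterations of made-up assignments
local iters = {
  {1, 1, 2, 2},
  {1, 2, 2, 2},
  {1, 2, 2, 2},
}
local niter, nsamples = #iters, 4

-- without a callback, the totals keep growing across iterations
local accumulated = toy_kmeans_counts(iters, 2)

-- with the reset callback, only the last iteration's counts survive
local final = toy_kmeans_counts(iters, 2, function(i, _, totalcounts)
  if i < niter then
    for c = 1, #totalcounts do totalcounts[c] = 0 end
  end
end)

print(sum(accumulated)) --> 12  (niter * nsamples)
print(sum(final))       --> 4   (nsamples)
```

Note that the reset trick only yields the final counts if all maxiter iterations actually run, since the callback compares the iteration index against maxiter.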