python, math, scikit-learn, k-means

KMeans: Reassign a data point to the second-nearest cluster?


I have a trained scikit-learn KMeans model.

When I use the model's predict function, it assigns a given data point to the nearest cluster, as expected.

What is the easiest way to have the model assign the data point to the SECOND-nearest or THIRD-nearest cluster instead?

I cannot seem to find this anywhere. (I might be missing something essential.)


Solution

  • The KMeans estimator has a transform(X) method that returns the Euclidean distance from each sample to every cluster centroid, as an array of shape (n_samples, n_clusters).

    With that, you can pick which cluster to assign the records to.

    Example:

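    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.datasets import load_digits
    from sklearn.preprocessing import scale
    
    # Seed NumPy's global RNG, which KMeans uses when random_state is not set
    np.random.seed(42)
    
    # Load the digits data set and standardize each feature
    digits = load_digits()
    data = scale(digits.data)
    n_digits = len(np.unique(digits.target))  # number of distinct digit classes
    
    km = KMeans(init='k-means++', n_clusters=n_digits, n_init=10)
    km.fit(data)
    predicted = km.predict(data)        # label of the nearest centroid per sample
    dist_centers = km.transform(data)   # distances, shape (n_samples, n_clusters)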
    

    To validate the transform output, we can compare the result of predict with the index of the smallest centroid distance in each row (np.argmin along axis 1):

    >>> np.allclose(km.predict(data), np.argmin(dist_centers, axis=1))
    True
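    To make it concrete what transform returns, the following sketch recomputes the sample-to-centroid distances by hand from cluster_centers_ and compares them with the transform output. The synthetic data, the number of clusters, and the names X, km_demo and manual are assumptions chosen for illustration; they are not part of the example above.

```python
import numpy as np
from sklearn.cluster import KMeans

# Small synthetic data set (hypothetical, for illustration only)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))

km_demo = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)

# Euclidean distance from every sample to every centroid, computed by hand.
# Broadcasting yields an array of shape (n_samples, n_clusters, n_features),
# and the norm over the last axis collapses it to (n_samples, n_clusters).
diff = X[:, None, :] - km_demo.cluster_centers_[None, :, :]
manual = np.linalg.norm(diff, axis=2)

print(manual.shape)
print(np.allclose(manual, km_demo.transform(X)))
```

    Both arrays should agree up to floating-point error, which is why argmin over the rows of the transform output reproduces predict.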
    

    Finally, np.argsort along axis 1 returns, for each row of the distance array, the cluster indices ordered from nearest to farthest. Column 0 of the result holds the label of the nearest cluster for each sample (the same labels that predict returns), column 1 the second-nearest cluster, column 2 the third-nearest, and so on.

    >>> print(predicted)
    [0 3 3 ... 3 7 7]
    
    >>> print(np.argsort(dist_centers, axis=1))
    [[0 7 4 ... 8 6 5]
     [3 9 4 ... 6 0 5]
     [3 9 4 ... 8 6 5]
     ...
     [3 1 9 ... 8 6 5]
     [7 0 9 ... 8 6 5]
     [7 3 1 ... 9 6 5]]
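
    To answer the original question directly, one way to get the second- or third-nearest cluster is to select the corresponding column of the argsort result. The sketch below wraps this in a small helper; the function name kth_nearest_cluster, the synthetic data, and the variable names are hypothetical and only serve the illustration.

```python
import numpy as np
from sklearn.cluster import KMeans


def kth_nearest_cluster(model, X, k):
    """Return, for each row of X, the label of its k-th nearest centroid.

    k is zero-based: k=0 is the nearest cluster, k=1 the second nearest.
    """
    distances = model.transform(X)          # shape (n_samples, n_clusters)
    order = np.argsort(distances, axis=1)   # cluster labels, nearest first
    return order[:, k]


# Small synthetic data set (hypothetical, for illustration only)
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 4))
model = KMeans(n_clusters=5, n_init=10, random_state=0).fit(X)

nearest = kth_nearest_cluster(model, X, 0)  # expected to match model.predict(X)
second = kth_nearest_cluster(model, X, 1)   # second-nearest cluster per sample
third = kth_nearest_cluster(model, X, 2)    # third-nearest cluster per sample

print(nearest[:10])
print(second[:10])
print(third[:10])
```

    With the fitted model from the example above, the equivalent call would be np.argsort(dist_centers, axis=1)[:, 1] for the second-nearest cluster. Sorting every row is simple and cheap when the number of clusters is small; for a large number of clusters, np.argpartition could avoid the full sort.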