
Retrieve 100 samples closest to the centroids of each cluster after k-means clustering in R


I'm trying to reduce the size of my input data by first running k-means clustering in R and then sampling 50 to 100 observations from each cluster as representatives for downstream classification and feature selection.

The original dataset was split 80/20, and the 80% training portion went into k-means. The input data has 2 label columns and 110 numeric variable columns. From the label column, I know there are 7 different drug treatments. In parallel, I used the elbow method to find the optimal number of clusters K, which came out at around 8. I picked 10 anyway, to have more clusters to sample from downstream.
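A minimal sketch of the 80/20 split and the elbow check described above, assuming a purely numeric feature matrix. The simulated data (200 rows, 5 columns) is hypothetical and only stands in for the real 110 numeric columns. The elbow is read off the plot of total within-cluster sum of squares (the tot.withinss element returned by kmeans()) against k.

```r
# Hypothetical stand-in for the real numeric features (simulated data)
set.seed(1)
n <- 200
num <- matrix(rnorm(n * 5), n, 5)

# 80/20 split by sampling row indices
train_idx <- sample(seq_len(n), size = round(0.8 * n))
train <- num[train_idx, ]
test  <- num[-train_idx, ]

# Elbow method: total within-cluster sum of squares for k = 1..15,
# computed on the scaled training portion only
train_scaled <- scale(train)
wss <- sapply(1:15, function(k) {
  kmeans(train_scaled, centers = k, nstart = 10, iter.max = 50)$tot.withinss
})

# Look for the "bend" where adding clusters stops paying off
plot(1:15, wss, type = "b",
     xlab = "Number of clusters k", ylab = "Total within-cluster SS")
```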

Now I have run model <- kmeans(), but the output list leaves me a little confused about what to do next. Because only the numeric variables can be scaled and passed to the kmeans function, the resulting cluster memberships no longer carry the treatment labels. I can work around this by appending the cluster membership to the original training data table.
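One way to do the scaling and the append step, sketched on a hypothetical training table: the column names treatment and sample_id and all values are simulated assumptions, not the asker's data. Because kmeans() returns model$cluster in the same row order as its input, it can be attached directly as a new column.

```r
set.seed(2)
# Hypothetical training table: two label columns plus numeric features
training_set <- data.frame(
  treatment = sample(paste0("drug", 1:7), 140, replace = TRUE),
  sample_id = sprintf("s%03d", 1:140),
  matrix(rnorm(140 * 6), 140, 6)
)

# Identify the numeric columns; the label columns are left out of kmeans()
num_cols <- sapply(training_set, is.numeric)
scaled <- scale(training_set[, num_cols])
model <- kmeans(scaled, centers = 10, nstart = 25)

# Row order is preserved, so the cluster vector lines up with training_set
training_set$cluster <- model$cluster
head(training_set[, c("treatment", "sample_id", "cluster")])
```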

Then, for the 10 centroids, how do I find out what their labels are? I can't just do

training_set$centroids <- model$centers
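A centroid is an average of scaled features, so it has no label of its own; a common workaround is to describe each cluster by the labels of its members. The sketch below, on simulated data with a hypothetical treatment vector, cross-tabulates cluster against treatment and reports the most frequent treatment per cluster. Row k of model$centers belongs to cluster k, so these summaries can be matched to the centroids by row number.

```r
set.seed(3)
# Hypothetical labels and features (simulated)
treatment <- sample(paste0("drug", 1:7), 140, replace = TRUE)
feats <- matrix(rnorm(140 * 6), 140, 6)
model <- kmeans(scale(feats), centers = 10, nstart = 25)

# Cross-tabulate cluster membership against the treatment label
tab <- table(cluster = model$cluster, treatment = treatment)
print(tab)

# Most frequent treatment per cluster and its share of that cluster
majority <- colnames(tab)[apply(tab, 1, which.max)]
purity <- apply(tab, 1, max) / rowSums(tab)
data.frame(cluster = rownames(tab), majority, purity = round(purity, 2))
```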

Most importantly, how do I find the 100 samples in each cluster that are closest to their respective centroid? I have seen one post here for Python (Output 50 samples closest to each cluster center using scikit-learn.k-means library) but no R resources yet. Any pointers?


Solution

  • First, we need a reproducible example of your data:

    set.seed(42)
    x <- matrix(runif(150), 50, 3)
    kmeans.x <- kmeans(x, 10)
    

    Now you want to find the observations in the original data x that are closest to the centroids computed by kmeans() and stored in kmeans.x$centers. We use the get.knnx() function from package FNN. We will just get the 5 closest observations for each of the 10 clusters.

    library(FNN)
    y <- get.knnx(x, kmeans.x$centers, 5)
    str(y)
    # List of 2
    #  $ nn.index: int [1:10, 1:5] 42 40 50 22 39 47 11 7 8 16 ...
    #  $ nn.dist : num [1:10, 1:5] 0.1237 0.0669 0.1316 0.1194 0.1253 ...
    y$nn.index[1, ]
    # [1] 42 38  3 22 43
    idx1 <- sort(y$nn.index[1, ])
    cbind(idx1, x[idx1, ])
    #      idx1                          
    # [1,]    3 0.28614 0.3984854 0.21657
    # [2,]   22 0.13871 0.1404791 0.41064
    # [3,]   38 0.20766 0.0899805 0.11372
    # [4,]   42 0.43577 0.0002389 0.08026
    # [5,]   43 0.03743 0.2085700 0.46407
    

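    The row indices of the nearest neighbors are stored in nn.index, ordered by increasing distance, so for the first cluster the 5 closest observations are rows 42, 38, 3, 22 and 43. For 100 samples per cluster, set the third argument (k) to 100. Note that get.knnx() searches all rows of x, so a returned neighbor is not guaranteed to belong to that centroid's cluster; to restrict the search, call it once per cluster with only that cluster's rows as the data (the returned indices then refer to that subset).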
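If you want each cluster's samples to come only from that cluster's members, and to avoid the FNN dependency, a base R sketch follows. closest_in_cluster() is a hypothetical helper, not part of any package; the block recreates the reproducible x and kmeans.x from the answer so it runs on its own, and returns, for each cluster, up to n member row indices sorted by Euclidean distance to that cluster's centroid.

```r
set.seed(42)
x <- matrix(runif(150), 50, 3)
kmeans.x <- kmeans(x, 10)

# Hypothetical helper: for each cluster k, return the row indices (into x)
# of up to n members of cluster k, nearest to centroid k first.
closest_in_cluster <- function(x, km, n) {
  lapply(seq_len(nrow(km$centers)), function(k) {
    members <- which(km$cluster == k)
    # Euclidean distance from each member to centroid k
    # (t() puts members in columns so the centroid recycles down each column)
    d <- sqrt(colSums((t(x[members, , drop = FALSE]) - km$centers[k, ])^2))
    members[order(d)][seq_len(min(n, length(members)))]
  })
}

sel <- closest_in_cluster(x, kmeans.x, 5)
sel[[1]]
```

With the real data you would pass the scaled training matrix and n = 100, then pull the labelled rows with training_set[unlist(sel), ]; a cluster with fewer than n members simply returns all of its members.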