
How to apply a k-medoids clustering solution (for example, PAM) to another dataset?


I'm looking for a way to apply the cluster solution from a k-medoids algorithm (I'm using PAM) fitted on one sample to another sample. For the k-means algorithm, I think this can be done as follows: for data1, get the centroids from the clustering result; then, for each observation in data2, calculate the distance to each centroid and assign the observation to its closest centroid. This applies the clustering solution from data1 to data2. However, the k-medoids algorithm (for example, PAM) uses medoids as cluster centers instead of means, so it is not clear to me how to apply its clustering solution from one sample to another. Could anyone help with this? Many thanks!
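The k-means procedure described above can be sketched in base R. This is a minimal illustration that assumes Euclidean distance and uses a random 90/60 split of the iris data as stand-ins for data1 and data2. `nearest_centre()` is a hypothetical helper written for this example, not a package function.

```r
# Sketch: transfer a k-means solution fitted on data1 to data2
# by assigning each new observation to its nearest centroid.
set.seed(100)
idx   <- sample(nrow(iris), 90)
data1 <- as.matrix(iris[idx, 1:4])   # sample used to fit the clusters
data2 <- as.matrix(iris[-idx, 1:4])  # sample the solution is applied to

km <- kmeans(data1, centers = 3, nstart = 10)

# Hypothetical helper: for every row of x, return the index of the closest
# centre (squared Euclidean distance); centres holds one centre per row.
nearest_centre <- function(x, centres) {
  d <- apply(centres, 1, function(ctr) rowSums(sweep(x, 2, ctr)^2))
  d <- matrix(d, nrow = nrow(x))     # keep an n x k matrix even when n == 1
  unname(apply(d, 1, which.min))
}

labels2 <- nearest_centre(data2, km$centers)
table(labels2)
```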


Solution

  • Clusters are still assigned by distance to the centres. The difference is that with k-medoids, each centre (the medoid) is an actual data point from the dataset. See the R code below:

    library(ClusterR)
    library(ggplot2)
    set.seed(100)
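    # use the numeric columns of iris, split into two samples:
    # 90 rows for data_a (fitting) and the remaining 60 for data_b
    a = sample(nrow(iris),90)
    data_b = iris[-a,1:4]
    data_a = iris[a,1:4]
    
    # perform k-medoids (PAM) clustering on data_a with k = 3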
    cm = Cluster_Medoids(data_a,clusters=3)
    

You can see that the medoids are actual data points (the row names are their row numbers in iris):

    cm$medoids
        Sepal.Length Sepal.Width Petal.Length Petal.Width
    95           5.6         2.7          4.2         1.3
    12           4.8         3.4          1.6         0.2
    111          6.5         3.2          5.1         2.0
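Since the question uses PAM, the same property can be checked with `pam()` from the cluster package. This sketch assumes cluster is installed (it ships as a recommended package with most R installations). `pam()` is a different implementation from ClusterR's `Cluster_Medoids()`, so its medoids are not guaranteed to match the ones printed above. `fit$id.med` gives the row positions of the medoids within the training data.

```r
library(cluster)

# Same split as in the answer: 90 rows of iris for fitting
set.seed(100)
a      <- sample(nrow(iris), 90)
data_a <- iris[a, 1:4]

# PAM with k = 3 and the default Euclidean metric
fit <- pam(data_a, k = 3)

# Each medoid is an actual row of data_a: id.med gives its position
fit$id.med
fit$medoids
all.equal(unname(fit$medoids), unname(as.matrix(data_a)[fit$id.med, ]))
```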
    

Next, we assign the observations in data_b to the medoids in cm$medoids with predict_Medoids():

    pm = predict_Medoids(data_b,MEDOIDS=cm$medoids)
    

We can also do the assignment by hand. Compute the distance between each observation in data_b and each medoid from the first dataset, then assign each observation to the cluster of its closest medoid:

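    # rows 1-3 of M are the medoids; the remaining rows are data_b
    M = as.matrix(dist(rbind(cm$medoids,data_b)))
    # assign each observation of data_b to its closest medoid (columns 1-3)
    labs = sapply(4:nrow(M),function(i)which.min(M[i,1:3]))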
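Building the full `dist()` matrix of medoids plus new observations also computes every pairwise distance among the new observations, and none of those are used. The minimal base-R sketch below computes only the observation-to-medoid distances. `assign_to_medoids()` is a hypothetical helper. The medoids are the three iris rows (95, 12, 111) shown in the `cm$medoids` output above. The `method` argument should match the metric used when fitting. Euclidean is the default here, and Manhattan is included as an assumed alternative.

```r
# Hypothetical helper: assign each row of newdata to its closest medoid.
# Only the n x k observation-to-medoid distances are computed.
assign_to_medoids <- function(newdata, medoids,
                              method = c("euclidean", "manhattan")) {
  method  <- match.arg(method)
  x       <- as.matrix(newdata)
  medoids <- as.matrix(medoids)
  stopifnot(ncol(x) == ncol(medoids))
  d <- vapply(seq_len(nrow(medoids)), function(j) {
    delta <- sweep(x, 2, medoids[j, ])        # subtract medoid j from each row
    if (method == "euclidean") sqrt(rowSums(delta^2)) else rowSums(abs(delta))
  }, numeric(nrow(x)))
  d <- matrix(d, nrow = nrow(x))              # n x k, even when n == 1
  apply(d, 1, which.min)                      # index of the nearest medoid
}

# Medoids from the cm$medoids output above (iris rows 95, 12, 111)
medoids <- iris[c(95, 12, 111), 1:4]
labs2 <- assign_to_medoids(iris[, 1:4], medoids)
table(labs2)
```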
    

Checking the result, the manually computed labels in labs agree with pm$clusters from ClusterR for all 60 observations:

    table(pm$clusters==labs)
    
    TRUE 
      60 
    

We can visualize both datasets on their first two principal components:

    # project both datasets onto their first two principal components
    PCA = prcomp(rbind(data_a,data_b))$x
    plotdf = data.frame(PCA[,1:2],
                        label=c(cm$clusters,pm$clusters),
                        dataset=rep(c("train","pred"),c(nrow(data_a),nrow(data_b))))
    
    ggplot(plotdf,aes(x=PC1,y=PC2,col=factor(label),shape=dataset)) +
      geom_point() + scale_color_brewer(palette="Paired") + theme_bw()
    
