r, ggplot2, cluster-computing, silhouette

How can I color by a variable other than cluster number in fviz_silhouette?


I am using the factoextra package in R to generate a silhouette plot. By default, fviz_silhouette colors the bars by cluster. I want to color them by another variable, Site, which I have stored as x. I have tried setting both fill and color to the site variable, but nothing seems to work. I have also tried scale_color_manual and scale_fill_discrete. I think the key is in scale_fill_discrete, as this line from the source code, pointed out by another user, suggests:

mapping <- aes_string(x = "name", y = "sil_width", color = "cluster", fill = "cluster")

I basically need to change color = "cluster" to color = "x".
I have reverted the final plotting code to its most basic form:

pamspec <- pam(spec, 3, keep.diss = TRUE) 
plot(pamspec)

spec <- cbind(pamspec$clustering)
autoplot(pam(spec,3), frame=TRUE, frame.type = "norm")

pamspec$site <- spec$Site
x <- pamspec$site
fviz_silhouette(pamspec, label=TRUE) + theme_classic()

Solution

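  • I suggest the following modified version of fviz_silhouette, which takes an additional var.col argument: a vector with one value per observation that is mapped to the bar color and fill instead of the cluster.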

    myfviz_silhouette <- function (sil.obj, var.col, label = FALSE, print.summary = TRUE, ...) {
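        if (inherits(sil.obj, c("eclust", "hcut", "pam", "clara", 
            "fanny"))) {
            df <- as.data.frame(sil.obj$silinfo$widths, stringsAsFactors = TRUE)
            # silinfo$widths is already sorted by cluster, so map each row back to
            # its original position before attaching var.col (given in data order).
            # Assumes row names are observation labels, or row indices if unlabeled.
            clus <- if (!is.null(sil.obj$clustering)) sil.obj$clustering else sil.obj$cluster
            ids <- if (!is.null(names(clus))) names(clus) else as.character(seq_along(clus))
            df$var_col <- var.col[match(rownames(df), ids)]
        }
        else if (inherits(sil.obj, "silhouette")) {
            df <- as.data.frame(sil.obj[, 1:3], stringsAsFactors = TRUE)
            # Assumes an unsorted silhouette object, whose rows are in data order
            df$var_col <- var.col
        }
        else stop("Don't support an object of class ", class(sil.obj))
        df <- df[order(df$cluster, -df$sil_width), ]
        if (!is.null(rownames(df))) 
            df$name <- factor(rownames(df), levels = rownames(df))
        else df$name <- as.factor(1:nrow(df))
        df$cluster <- as.factor(df$cluster)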
        mapping <- aes_string(x = "name", y = "sil_width", color = "var_col", 
            fill = "var_col")
        p <- ggplot(df, mapping) + geom_bar(stat = "identity") + 
            labs(y = "Silhouette width Si", x = "", title = paste0("Clusters silhouette plot ", 
                "\n Average silhouette width: ", round(mean(df$sil_width), 
                    2))) + ggplot2::ylim(c(NA, 1)) + geom_hline(yintercept = mean(df$sil_width), 
            linetype = "dashed", color = "red")
        p <- ggpubr::ggpar(p, ...)
        if (!label) 
            p <- p + theme(axis.text.x = element_blank(), axis.ticks.x = element_blank())
        else if (label) 
            p <- p + theme(axis.text.x = element_text(angle = 45))
        ave <- tapply(df$sil_width, df$cluster, mean)
        n <- tapply(df$cluster, df$cluster, length)
        sil.sum <- data.frame(cluster = names(ave), size = n, ave.sil.width = round(ave, 
            2), stringsAsFactors = TRUE)
        if (print.summary) 
            print(sil.sum)
        p
    }
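
    One detail to watch when coloring by an external variable: pam stores silinfo$widths sorted by cluster and by decreasing silhouette width, so a vector given in data order has to be reordered to match those rows. The sketch below is only a sanity check, using the cluster package with iris as stand-in data; ids, idx and species_sorted are illustrative names, not part of factoextra.

    ```r
    library(cluster)

    pamspec <- pam(iris[, -5], 3, keep.diss = TRUE)

    # Rows are sorted by cluster and decreasing silhouette width;
    # the row names identify the original observations.
    widths <- as.data.frame(pamspec$silinfo$widths)

    # Observation identifiers in data order: labels if present,
    # otherwise the row indices as character strings.
    clus <- pamspec$clustering
    ids <- if (is.null(names(clus))) as.character(seq_along(clus)) else names(clus)

    # Position of each sorted row in the original data.
    idx <- match(rownames(widths), ids)

    # External variable reordered to line up with the sorted rows.
    species_sorted <- iris$Species[idx]

    # Every row must be matched, and the cluster stored in each sorted row
    # must equal the cluster of the observation it was matched to.
    stopifnot(!anyNA(idx))
    stopifnot(all(widths$cluster == clus[idx]))

    # Cross-tabulate the aligned variable against the clusters.
    table(cluster = widths$cluster, species = species_sorted)
    ```

    If either stopifnot call fails for your data, the row names of the silhouette table do not identify the observations in the way assumed here, and the coloring variable cannot be aligned this way.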
    

    Example usage of myfviz_silhouette:

    library(factoextra)
    library(cluster)
    pamspec <- pam(iris[,-5], 3, keep.diss = TRUE) 
    
    color_var <- iris$Species
    myfviz_silhouette(pamspec, color_var, label=TRUE) +
       theme_classic()
    

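
    Since the question mentions scale_color_manual: once fill and color are mapped to the site variable, a manual palette can be added like any other ggplot2 scale. This is a minimal sketch that builds the bars directly with ggplot2 rather than through factoextra; iris$Species stands in for Site, and the hex colors in site_cols are arbitrary choices.

    ```r
    library(cluster)
    library(ggplot2)

    pamspec <- pam(iris[, -5], 3)

    # Silhouette table, sorted by cluster and decreasing width.
    sil <- as.data.frame(pamspec$silinfo$widths)

    # Align the site variable with the sorted rows (row names are
    # observation labels, or row indices for unlabeled data).
    clus <- pamspec$clustering
    ids <- if (is.null(names(clus))) as.character(seq_along(clus)) else names(clus)
    sil$site <- iris$Species[match(rownames(sil), ids)]

    # Fix the factor levels so the x axis keeps the sorted order.
    sil$name <- factor(rownames(sil), levels = rownames(sil))

    # Hypothetical palette, one named color per site level.
    site_cols <- c(setosa = "#1b9e77", versicolor = "#d95f02", virginica = "#7570b3")

    p <- ggplot(sil, aes(x = name, y = sil_width, fill = site, color = site)) +
      geom_col() +
      scale_fill_manual(values = site_cols) +
      scale_color_manual(values = site_cols) +
      labs(x = "", y = "Silhouette width Si", fill = "Site", color = "Site") +
      theme_classic() +
      theme(axis.text.x = element_blank(), axis.ticks.x = element_blank())
    p
    ```

    The same two scale_*_manual layers can be appended to the plot returned by myfviz_silhouette.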