Search code examples
Tags: r, ggplot2, legend, patchwork

Collecting common legends while keeping a specific legend beside each plot in patchwork


I am making a multi-panel figure in R using wrap_plots from the patchwork package, and I'd like to have the final plot show a common legend for point shape and color, but have each plot keep its own fill legend. Here's some mock data:

library(ggplot2)
library(dplyr)
library(patchwork)
library(ggrepel)

# Mock data: 3 sites x 2 treatments x 5 species (30 rows).
# Fix the RNG seed so the rpois() draws -- and hence the figure -- are
# reproducible from run to run.
set.seed(123)
df1 <- data.frame(
  site = rep(c("a", "b", "c"), each = 10),
  treat = rep(c("treat_1", "treat_2"), each = 5, times = 3),
  species = rep(c("spA", "spB", "spC", "spD", "spE"), times = 6),
  val = rpois(30, 10)
)

# Per-species lookup tables (instead of one subset assignment per species).
val2_by_species <- c(spA = 15, spB = 10, spC = 12, spD = 8, spE = 5)
top_by_species <- c(
  spA = "top_1", spB = "top_2", spC = "top_3", spD = "none", spE = "none"
)
df1$val2 <- unname(val2_by_species[as.character(df1$species)])
df1$top <- unname(top_by_species[as.character(df1$species)])

### Contour plot
# Build the contour/tile surface plot for a single site.
#
# The surface is total = anomaly * abun, evaluated on a 50 x 50 grid spanning
# the site's observed `val2` (x axis) and `val` (y axis) ranges. The observed
# points are overlaid (colour = treat, shape/alpha = top), and the top-ranked
# species of "treat_1" are labelled.
#
# Args:
#   data:       data frame with columns site, treat, species, val, val2, top.
#   site_name:  the value of `site` to plot.
#   print_plot: if TRUE (the default, matching the original behaviour), the
#               plot is also printed as a side effect.
# Returns the ggplot object, invisibly.
contour_plot <- function(data, site_name, print_plot = TRUE) {
  # `.env$` refers to the function argument, so a data column that happens to
  # be called `site_name` cannot shadow it inside filter()'s data mask.
  site_data <- data %>% filter(site == .env$site_name)
  if (nrow(site_data) == 0) {
    stop("No rows found for site '", site_name, "'.", call. = FALSE)
  }

  # Grid over the observed ranges of both variables
  value_2 <- seq(min(site_data$val2), max(site_data$val2), length.out = 50)
  value_1 <- seq(min(site_data$val), max(site_data$val), length.out = 50)

  grid_df <- expand.grid(anomaly = value_2, abun = value_1) %>%
    mutate(total = anomaly * abun)

  # Species to label: the three top-ranked ones, restricted to treat_1 rows
  # (presumably so each species is labelled once rather than per treatment)
  top_contributors <- site_data %>%
    filter(top %in% c("top_1", "top_2", "top_3"), treat == "treat_1")

  # The alpha and shape scales both map `top`; giving them identical labels
  # lets ggplot2 merge them into a single legend.
  top_labels <- c(
    "top_1" = "Top 1", "top_2" = "Top 2", "top_3" = "Top 3", "none" = "N/A"
  )

  p <- ggplot(grid_df, aes(x = anomaly, y = abun)) +
    geom_tile(aes(fill = total)) + # Tiles represent the surface
    geom_contour(aes(z = total), color = "gray50") + # Contour lines
    geom_point(
      data = site_data,
      aes(x = val2, y = val, color = treat, alpha = top, shape = top),
      size = 3
    ) +
    scale_alpha_manual(
      values = c("top_1" = 1, "top_2" = 1, "top_3" = 1, "none" = 0.4),
      labels = top_labels
    ) +
    scale_shape_manual(
      values = c("top_1" = 8, "top_2" = 18, "top_3" = 17, "none" = 20),
      labels = top_labels
    ) +
    # No trailing comma here: an empty argument forwarded through `...` is
    # not guaranteed to be accepted.
    scale_fill_gradient2(
      low = "blue", mid = "white", high = "red", midpoint = 0
    ) +
    ggtitle(site_name) +
    geom_label_repel(
      data = top_contributors,
      aes(x = val2, y = val, label = species),
      size = 3,
      box.padding = 0.5,    # Padding around the label box
      point.padding = 0.3,
      fill = "white",       # Label background
      color = "black",
      nudge_x = 0.1,        # Slight horizontal nudge
      nudge_y = 0.0001,
      min.segment.length = unit(0, "lines"), # Always draw leader segments
      segment.color = "grey50"
    ) +
    theme_minimal() +
    theme(
      axis.text.x = element_text(size = 12),
      axis.text.y = element_text(size = 14),
      axis.title = element_text(size = 14, face = "bold")
    ) +
    coord_cartesian(expand = FALSE) # No padding around the data range

  if (print_plot) {
    print(p)
  }
  invisible(p)
}

# Build one contour plot per site and store them in a list named by site.
# lapply() avoids growing the list inside a loop, and (unlike a loop that
# assigns to `plot`) leaves no global object masking base::plot().
site_names <- unique(df1$site)
plots_contours <- setNames(
  lapply(site_names, function(site_name) contour_plot(df1, site_name)),
  site_names
)

### Save figure outputs
# Unpack the per-site panels by position (first, second and third site).
plot_a <- plots_contours[[1L]]
plot_b <- plots_contours[[2L]]
plot_c <- plots_contours[[3L]]

### Multi-panel figure
# 2 x 2 grid with one cell left empty; guides and axis titles from the three
# panels are collected into a shared area.
wrap_plots(plot_a, plot_b, plot_c, nrow = 2, ncol = 2) +
  plot_layout(axis_titles = "collect", guides = "collect")

This correctly produces a shared legend for point shape and color, but it also duplicates the fill legend. What I'd ultimately like is for each panel of the multi-panel figure to show its own fill guide (as in `plots_contours[[1]] + guides(shape = "none", color = "none", alpha = "none")`), while keeping a single shared legend for point shape and color. (Current output: see the attached image.)


Solution

    library(ggpubr) ## additional required package
    
    ## Strip the color, shape and alpha guides from a plot, leaving any other
    ## guide (here, the per-panel fill colourbar) in place.
    ## plt: a ggplot object. Returns the modified plot.
    legend_rem <- function(plt) {
      hide <- guides(shape = "none", color = "none", alpha = "none")
      plt + hide
    }
    
    ## Shared color / shape(+alpha) legend on its own: taken from one panel
    ## with the fill guide switched off, so only the point legend remains.
    clr_shp_lgnd <- get_legend(plot_a + guides(fill = "none"))
    
    ## patchwork design: a, b, c are the three panels, d holds the shared
    ## legend, and '#' marks empty cells.
    design <- "
      aaaaaabbbbbb##
      aaaaaabbbbbbdd
      aaaaaabbbbbbdd
      cccccc######dd
      cccccc######dd
      cccccc########
    "
    
    ## One source plot per area: each panel keeps only its own fill legend,
    ## and the shared point legend fills area d.
    legend_rem(plot_a) +
      legend_rem(plot_b) +
      legend_rem(plot_c) +
      as_ggplot(clr_shp_lgnd) +
      plot_layout(design = design)
    

    If you also want to collect the axis titles, we can create the color and shape legend together with the axis titles (there may be other ways, such as annotating, but this is the approach I could think of):

    ## Legend panel carrying the shared color/shape guide plus the axis titles:
    ## a copy of plot_a with its fill guide dropped and its contents hidden.
    clr_shp_lgnd <- plot_a +
      guides(fill = "none") +
      theme(axis.text.y = element_blank(),
            plot.title = element_blank(),
            axis.text.x = element_blank(),
            axis.ticks = element_blank(),
            # NOTE(review): a numeric legend.position is deprecated as of
            # ggplot2 3.5.0 (use legend.position = "inside" together with
            # legend.position.inside); it is kept here for older ggplot2.
            legend.position = c(0.5, 0.5),
            legend.title = element_blank(),
            panel.grid = element_blank(),
            # panel.border is drawn on top of the panel, so a white fill hides
            # the tiles, contours, points and labels copied from plot_a.
            # `linewidth` replaces `size`, which is deprecated since ggplot2 3.4.0.
            panel.border = element_rect(colour = "white", fill = "white",
                                        linewidth = 1))
    
    
    ## Panels a, b and c, the legend-plus-axis-titles plot in d; cells marked
    ## '#' stay empty.
    design <- "
      aaaaaabbbbbb##
      aaaaaabbbbbbdd
      aaaaaabbbbbbdd
      cccccc######dd
      cccccc######dd
      cccccc########
    "
    
    ## Guides are not collected, so each panel keeps its own fill legend;
    ## only the axis titles are collected.
    (legend_rem(plot_a) + legend_rem(plot_b) + legend_rem(plot_c) +
       clr_shp_lgnd) +
      plot_layout(axis_titles = "collect", design = design)