
Use patchwork to create a facet grid with strips


I'm using a function that returns a ggplot inside a grid-search algorithm, and I want to arrange the resulting plots in a grid, the way facet_grid() would.

I cannot use standard facets because each scenario can only be interpreted on its own scale, so each plot has independent x/y limits.

Here is a reproducible example:

library(tidyverse)
library(patchwork)
get_plot = function(a, b) list(ggplot(iris, aes(Sepal.Length, Sepal.Width)) +
                                 geom_point() + ggtitle(paste0(a, "--", b)))
x = expand_grid(dist_fun = c("rnorm", "rexp"),
                assumption=c("linear", "square")) %>%
  rowwise() %>%
  mutate(plot=get_plot(dist_fun, assumption))
x
#> # A tibble: 4 × 3
#> # Rowwise: 
#>   dist_fun assumption plot  
#>   <chr>    <chr>      <list>
#> 1 rnorm    linear     <gg>  
#> 2 rnorm    square     <gg>  
#> 3 rexp     linear     <gg>  
#> 4 rexp     square     <gg>

# expected outcome, but without the strips:
wrap_plots(x$plot)

Created on 2023-12-08 with reprex v2.0.2

Is there a way to replace the ggtitle() labels with facet-like strips?


Solution

  • Well, it's not really a question of which function to use. (: Getting a facet_grid()-like look with patchwork requires writing a custom function that manipulates your ggplot objects before they are passed to wrap_plots(). To this end, I first added a column containing each plot's numeric position in the patch. The custom function then removes the axes and adds the strips conditional on that position:

    Note: While I love such patchwork exercises, IMHO the easier approach would be to return the data from your function (or a list containing both the data and the plot), then use rbind()/bind_rows() + facet_grid(), as already suggested in the comments.

    library(tidyverse)
    library(patchwork)
    
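    # Same as in the question, minus ggtitle(): the strips replace the titles
    get_plot <- function(a, b) {
      list(ggplot(iris, aes(Sepal.Length, Sepal.Width)) +
        geom_point())
    }
    x <- expand_grid(
      dist_fun = c("rnorm", "rexp"),
      assumption = c("linear", "square")
    ) %>%
      rowwise() %>% # one get_plot() call per scenario, as in the question
      mutate(plot = get_plot(dist_fun, assumption)) %>%
      ungroup() # so that row_number() below counts 1..4, not 1 in every row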
    
    x |>
      mutate(pos = row_number()) |>
      pmap(
        \(...) {
          args <- list(...)
          # Column and row of this plot in a 2-column grid filled row by row
          col <- 1 + (args$pos + 1) %% 2
          row <- 1 + (args$pos - 1) %/% 2
    
          # Top row: hide x-axis text and ticks (the bottom row keeps them)
          remove_x <- if (row == 1) {
            theme(
              axis.text.x = element_blank(),
              axis.ticks.x = element_blank(),
              axis.ticks.length.x = unit(0, "pt")
            )
          }
          # Right column: hide y-axis text and ticks (the left column keeps them)
          remove_y <- if (col == 2) {
            theme(
              axis.text.y = element_blank(),
              axis.ticks.y = element_blank(),
              axis.ticks.length.y = unit(0, "pt")
            )
          }
    
          # Top row: column strip labeled with the assumption
          facet_y <- if (row == 1) {
            # Wrap in quotes to get e.g.
            # . ~ "linear" instead of . ~ linear
            paste0("\"", args$assumption, "\"")
          } else {
            "."
          }
          # Right column: row strip labeled with the distribution function
          facet_x <- if (col == 2) {
            paste0("\"", args$dist_fun, "\"")
          } else {
            "."
          }
    
          # reformulate(rhs, lhs) builds lhs ~ rhs, i.e. rows ~ cols
          layer_facet <- facet_grid(reformulate(facet_y, facet_x))
    
          args$plot +
            remove_x +
            remove_y +
            layer_facet
        }
      ) |>
      wrap_plots(ncol = 2) &
      # `&` adds the theme to every plot in the patchwork, not just the last one
      theme(
        axis.title = element_blank()
      )
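The position arithmetic and strip logic above are hard-coded for a 2 x 2 grid. Here is a minimal sketch of how they might generalize to a grid with ncol columns and nrow rows, filled row by row as wrap_plots() does by default. strip_spec() is a hypothetical helper, not part of patchwork or ggplot2, and like the answer it assumes R >= 4.1 (which the \(...) lambda already requires):

```r
# Hypothetical helper: for the plot at position `pos` in a grid with
# `ncol` columns and `nrow` rows (filled row by row), work out where it
# sits, which axes to hide, and which facet formula draws only the
# outer strips (column labels on top, row labels on the right).
strip_spec <- function(pos, ncol, nrow, row_label, col_label) {
  col <- (pos - 1) %% ncol + 1
  row <- (pos - 1) %/% ncol + 1

  # Quote the labels so they become constant strings, not column names;
  # "." means "no strip on this side".
  cols_term <- if (row == 1) paste0("\"", col_label, "\"") else "."
  rows_term <- if (col == ncol) paste0("\"", row_label, "\"") else "."

  list(
    row = row,
    col = col,
    drop_x = row < nrow, # only the bottom row keeps its x axis
    drop_y = col > 1,    # only the left column keeps its y axis
    formula = reformulate(cols_term, rows_term) # rows ~ cols
  )
}

# Top-right panel of the 2 x 2 example: row 1, column 2,
# with the formula "rnorm" ~ "square"
spec <- strip_spec(2, ncol = 2, nrow = 2,
                   row_label = "rnorm", col_label = "square")
spec$row
spec$col
spec$formula
```

Inside the pmap() call, the result would be used as args$plot + facet_grid(spec$formula), adding the axis-removal themes only when spec$drop_x or spec$drop_y is TRUE.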
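The bind_rows() + facet_grid() route mentioned in the note could look roughly like the sketch below. get_data() is a hypothetical stand-in for a version of get_plot() that returns each scenario's data instead of a finished plot; the iris columns are placeholders for whatever each scenario actually produces.

```r
library(tidyverse)

# Hypothetical: a variant of get_plot() that returns one scenario's data,
# tagged with the scenario it came from. The argument names match the
# columns of `scenarios` because pmap() passes columns by name.
get_data <- function(dist_fun, assumption) {
  data.frame(
    iris[c("Sepal.Length", "Sepal.Width")],
    dist_fun = dist_fun,
    assumption = assumption
  )
}

scenarios <- expand_grid(
  dist_fun = c("rnorm", "rexp"),
  assumption = c("linear", "square")
)

# One data frame per scenario, stacked into a single long data frame
all_data <- pmap(scenarios, get_data) |> bind_rows()

p <- ggplot(all_data, aes(Sepal.Length, Sepal.Width)) +
  geom_point() +
  # Rows = dist_fun, columns = assumption, as in the patchwork layout;
  # scales = "free" lets x vary by column and y vary by row
  facet_grid(dist_fun ~ assumption, scales = "free")
p
```

One caveat: with scales = "free", facet_grid() frees x per column and y per row, not per panel, so this only fits if scenarios sharing a row (or column) can share that axis. facet_wrap(~ dist_fun + assumption, scales = "free") frees every panel independently, at the cost of the grid-style row/column strips.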