Search code examples
Tags: r, ggplot2, patchwork

patchwork: generating a common legend


I have generated three plots and would like to combine them into one figure. The data (as `dput()` output) and the code for the plots are below:

# Example data for the ROC / PR curves: the observed binary outcome `out`
# (0/1) and the predicted probabilities of the full (`p1`) and restricted
# (`p2`) models, 100 observations. Assigned rather than printed with dput()
# so that the plotting code below can be run as-is in a fresh session.
ps <- structure(list(out = c(0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 
0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 
1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 
1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 
0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 
1, 0, 1), p1 = c(0.203288563788086, 0.593397908262453, 0.104478608490232, 
0.191949121414957, 0.131451273581716, 0.197662840772138, 0.475271986907143, 
0.999674001349968, 0.464228555892293, 0.307784508243616, 0.483972218264885, 
0.592054031086974, 0.519394485828694, 0.140996420764961, 0.715191009167964, 
0.161447437666636, 0.215532480635095, 0.512141004164495, 0.325618200726227, 
0.147688571495122, 0.587060672345883, 0.339084925099401, 0.406897042874474, 
0.260334679756635, 0.600584657342952, 0.30091607683028, 0.194842962759756, 
0.466005093086476, 0.62157971598547, 0.103213729114441, 0.611973079920802, 
0.340797053901145, 0.352770449747561, 0.782266356700567, 0.925336385275883, 
0.566815799016511, 0.911716642419561, 0.896102815901494, 0.985634484118366, 
0.269398492713071, 0.454145664817772, 0.258300464775039, 0.304414866253428, 
0.146170872786592, 0.603260865814924, 0.999999646567147, 0.509381303111335, 
0.233160200071857, 0.61050648554401, 0.140024461566546, 0.785072510664678, 
0.32209696620784, 0.211704380115593, 0.915054151530585, 0.280059805711919, 
0.260800073348821, 0.264564935816099, 0.364786135876787, 0.314587955084217, 
0.873181475153794, 0.212562097727904, 0.161257570078657, 0.239011201339171, 
0.245951130922908, 0.192314834634348, 0.431669247737303, 0.178673765213575, 
0.472989629383862, 0.22881006859004, 0.327803704637798, 0.140396326002643, 
0.229542264680936, 0.368074793643649, 0.251451892404397, 0.311004165920044, 
0.32141576096675, 0.320365845919904, 0.605394241031366, 0.313651664046354, 
0.383296458601734, 0.19724411942929, 0.264585196825289, 0.189331523639336, 
0.619571987070894, 0.219173976526817, 0.117791664602952, 0.165759812695153, 
0.0937050368823814, 0.954729558365975, 0.13564603779449, 0.222788963402769, 
0.174618457520783, 0.122777343457065, 0.379096358167151, 0.364369567415599, 
0.236639744450608, 0.449747269585931, 0.283437170381281, 0.0996580311901003, 
0.258892474371184), p2 = c(0.201791866675204, 0.528131702109111, 
0.132772837212832, 0.200780117697992, 0.147396315698057, 0.181403575768733, 
0.43442493789959, 0.999085024187367, 0.40087851525752, 0.285365004598898, 
0.451359257257437, 0.538584320447704, 0.48997782458568, 0.130097536963362, 
0.661516070857524, 0.169690532818658, 0.21913925575274, 0.552988203822168, 
0.305838305947951, 0.153999479719914, 0.558383678542237, 0.320063561525497, 
0.432599852851054, 0.253416941136616, 0.575173832168762, 0.295092969045072, 
0.201784269790454, 0.477022157405156, 0.56773491016289, 0.134866444444684, 
0.562192999847133, 0.348823388527, 0.339713966428903, 0.704022287840662, 
0.87470569354584, 0.558931522430035, 0.874036113779291, 0.876785599057654, 
0.973632162303989, 0.361796325757981, 0.460638802413689, 0.250744460743195, 
0.29034861228415, 0.156473978595486, 0.559039050288004, 0.999998587316118, 
0.442908959434392, 0.23202002509723, 0.562833009542875, 0.147396315698057, 
0.759842576555407, 0.258880616387473, 0.240729865505789, 0.875116692869777, 
0.274955888823107, 0.236289396383835, 0.258880616387473, 0.347654802280353, 
0.305838305947951, 0.743842630655276, 0.280497565858943, 0.169690532818658, 
0.223921536816737, 0.269357444122176, 0.197653306295674, 0.396919081721598, 
0.180211174706966, 0.447714291526475, 0.223406882796942, 0.317294821538698, 
0.147396315698057, 0.238651843622144, 0.345312218006159, 0.253970749804362, 
0.289735464316524, 0.327504170149824, 0.351002827274531, 0.604492744395317, 
0.269357444122176, 0.362591727588704, 0.212505029384632, 0.258880616387473, 
0.211019857075736, 0.568392478332797, 0.218415081445013, 0.15441240879629, 
0.18475833091903, 0.124471327983545, 0.937876567615899, 0.138259434665826, 
0.243991464941914, 0.234782614018068, 0.134866444444684, 0.333873238049351, 
0.338341008698206, 0.234063226687198, 0.421619679768826, 0.273939193597239, 
0.116505493530719, 0.255116520892176)), row.names = c(NA, -100L
), class = "data.frame")

# Long-format version of `ps` for the calibration plot: one row per
# observation and model. `out` is repeated once per model, `pred` stacks the
# predictions of the full model (`p1`) on top of those of the restricted
# model (`p2`), and the factor `Model` says which model each row belongs to.
# Built from `ps` instead of pasting a second dput() dump, so the two data
# sets cannot drift apart.
ps.long <- data.frame(
  out = rep(ps$out, times = 2L),
  pred = c(ps$p1, ps$p2),
  Model = factor(
    rep(c("Full", "Restricted"), each = nrow(ps)),
    levels = c("Full", "Restricted")
  )
)

#------------------------------------------------------------------------------------------------------------------------#

# Load packages

library(precrec)
library(ggplot2)
library(patchwork)
library(showtext)

#------------------------------------------------------------------------------------------------------------------------#

# Register the "EB Garamond" family from Google Fonts and switch on showtext,
# so that `family = "EB Garamond"` in the themes below resolves to it.

font_add_google(name = "EB Garamond", family = "EB Garamond")
showtext_auto(enable = TRUE)

#------------------------------------------------------------------------------------------------------------------------#

# Gather the scores of both models (p1 = "Full", p2 = "Restricted") together
# with the observed labels, then let precrec compute the curves.

scores <- with(ps, join_scores(p1, p2))
mod <- mmdata(
  scores = scores,
  labels = ps[["out"]],
  modnames = c("Full", "Restricted")
)
curves <- evalmod(mod)

#------------------------------------------------------------------------------------------------------------------------#

# Shared plot settings

# One colour per model, matched by level order ("Full", "Restricted" in both
# precrec's modnames and ps.long$Model).
model_colours <- c("steelblue", "red")

# Every axis in the figure is a rate or probability on [0, 1]. The breaks are
# given explicitly so the hand-written labels always line up with them; a bare
# `labels` vector only works while the automatic breaks happen to be exactly
# 0, .25, .50, .75 and 1. Returns fresh scale objects on every call.
prob_scales <- function() {
  breaks <- seq(0, 1, by = 0.25)
  labels <- c("0", ".25", ".50", ".75", "1")
  list(
    scale_x_continuous(limits = c(0, 1), breaks = breaks, labels = labels),
    scale_y_continuous(limits = c(0, 1), breaks = breaks, labels = labels)
  )
}

# Theme elements common to all three panels: EB Garamond 10pt text, plain
# black axis lines, no panel background or facet strips, legend at the bottom.
fig_theme <- theme(
  axis.text = element_text(size = 10, family = "EB Garamond"),
  axis.line = element_line(colour = "black"),
  axis.title.x = element_text(size = 10, family = "EB Garamond", vjust = -1),
  axis.title.y = element_text(size = 10, family = "EB Garamond", vjust = 2),
  panel.background = element_blank(),
  strip.background = element_blank(),
  strip.text = element_blank(),
  legend.position = "bottom",
  legend.key = element_blank(),
  legend.title = element_text(size = 10, family = "EB Garamond"),
  legend.text = element_text(size = 10, family = "EB Garamond")
)

# Extra settings for the precrec::autoplot() panels. `title` blanks all
# inherited title text (presumably to hide the plot title autoplot() adds);
# the axis and legend titles are re-specified in fig_theme, so they still
# appear.
curve_theme <- theme(
  title = element_blank(),
  panel.spacing.y = unit(2, "lines"),
  panel.border = element_blank(),
  panel.grid = element_blank()
)


# Plot ROC curve

roc <- autoplot(curves, curvetype = "ROC") +
  labs(x = "False positive rate", y = "True positive rate") +
  scale_color_manual("Model", values = model_colours) +
  fig_theme +
  curve_theme +
  prob_scales()


# Plot PR curve

pr <- autoplot(curves, curvetype = "PRC") +
  labs(x = "Recall", y = "Precision") +
  scale_color_manual("Model", values = model_colours) +
  fig_theme +
  curve_theme +
  prob_scales()


# Calibration plot: loess-smoothed observed outcome against the predicted
# probability, one curve per model, with the identity line for reference.

cal <- ggplot(data = ps.long, aes(x = pred, y = out, color = Model)) +
  # NOTE(review): on ggplot2 >= 3.4.0 `size` for lines is deprecated in
  # favour of `linewidth = .5`.
  geom_smooth(size = .5, se = FALSE, method = "loess", formula = y ~ x) +
  geom_abline(linetype = "dotted", alpha = .5) +
  scale_color_manual("Model", values = model_colours) +
  labs(x = "Predicted probability of outcome", y = "Outcome") +
  fig_theme +
  # NOTE(review): the y limits are scale limits, so any part of the loess
  # curve outside [0, 1] is dropped (with a warning) rather than clipped;
  # coord_cartesian(ylim = c(0, 1)) would clip instead.
  prob_scales()

Combining the three plots with patchwork:

# Place the three panels side by side, tag them A-C and try to gather the
# legends at the bottom. guides = "collect" only merges legends that are
# identical; the calibration panel's legend (drawn by geom_smooth) presumably
# differs from the precrec panels' one, which is why two legends remain.
(roc + pr + cal) +
  plot_layout(guides = "collect") +
  plot_annotation(tag_levels = "A") &
  theme(
    legend.position = "bottom",
    plot.tag = element_text(family = "EB Garamond", size = 8)
  )

This produces the following figure, but with two legends:

[Image: the combined figure, showing two legends]

However, I would like the figure to have only one legend. How can this be done?


Solution

  • Here is one way to tweak the legends within a patchwork call: switch the legend off on the ROC and calibration plots, so that only the PR plot's legend is drawn:

    # Drop the legends of the ROC (A) and calibration (C) panels so only the
    # PR panel (B) draws one. guides = "auto" does not collect at the top
    # level, so that legend stays with its own panel.
    (roc + theme(legend.position = "none")) +
      pr +
      (cal + theme(legend.position = "none")) +
      plot_layout(guides = "auto") +
      plot_annotation(tag_levels = "A") &
      theme(plot.tag = element_text(family = "EB Garamond", size = 8))
    

    [Image: the combined figure, showing a single legend]