Search code examples
rggplot2cowplot

Perfectly Align Density plots to scatterplot using cowplot


I am trying to build a function for bivariate plotting that taking 2 variables it is able to represent a marginal scatterplot and two lateral density plots.

The problem is that the density plot on the right does not align with the bottom axis.

Here is a sample data:

g1 = c(rnorm(200, mean=350, sd=100), rnorm(200, mean=700, sd=100))
g2 = c(rnorm(200, mean=350, sd=100), rnorm(200, mean=500, sd=100))
df_exp = data.frame(var1=log2(g1 + 1) , var2=log2(g2 + 1))

Here is the function:

    bivariate_plot <- function(df, var1, var2, density = T, box = F) {
    require(ggplot2)
    require(cowplot)
    scatter = ggplot(df, aes(eval(parse(text = var1)), eval(parse(text = var2)), color = "red")) +
            geom_point(alpha=.8)

    plot1 = ggplot(df, aes(eval(parse(text = var1)), fill = "red")) + geom_density(alpha=.5) 
    plot1 = plot1 + ylab("G1 density")

    plot2 = ggplot(df, aes(eval(parse(text = var2)),fill = "red")) + geom_density(alpha=.5) 
    plot2 = plot2 + ylab("G2 density")

    plot_grid(scatter, plot1, plot2, nrow=1, labels=c('A', 'B', 'C')) #Or labels="AUTO"


    # Avoid displaying duplicated legend
    plot1 = plot1 + theme(legend.position="none")
    plot2 = plot2 + theme(legend.position="none")

    # Homogenize scale of shared axes
    min_exp = min(df[[var1]], df[[var2]]) - 0.01
    max_exp = max(df[[var1]], df[[var2]]) + 0.01
    scatter = scatter + ylim(min_exp, max_exp)
    scatter = scatter + xlim(min_exp, max_exp)
    plot1 = plot1 + xlim(min_exp, max_exp)
    plot2 = plot2 + xlim(min_exp, max_exp)
    plot1 = plot1 + ylim(0, 2)
    plot2 = plot2 + ylim(0, 2)


    first_row = plot_grid(scatter, labels = c('A'))
    second_row = plot_grid(plot1, plot2, labels = c('B', 'C'), nrow = 1)
    gg_all = plot_grid(first_row, second_row, labels=c('', ''), ncol=1)

    # Display the legend
    scatter = scatter + theme(legend.justification=c(0, 1), legend.position=c(0, 1))



    # Flip axis of gg_dist_g2
    plot2 = plot2 + coord_flip()

    # Remove some duplicate axes
    plot1 = plot1 + theme(axis.title.x=element_blank(),
                          axis.text=element_blank(),
                          axis.line=element_blank(),
                          axis.ticks=element_blank())

    plot2 = plot2 + theme(axis.title.y=element_blank(),
                          axis.text=element_blank(),
                          axis.line=element_blank(),
                          axis.ticks=element_blank())

    # Modify margin c(top, right, bottom, left) to reduce the distance between plots
    #and align G1 density with the scatterplot
    plot1 = plot1 + theme(plot.margin = unit(c(0.5, 0, 0, 0.7), "cm"))
    scatter = scatter + theme(plot.margin = unit(c(0, 0, 0.5, 0.5), "cm"))
    plot2 = plot2 + theme(plot.margin = unit(c(0, 0.5, 0.5, 0), "cm"))

    # Combine all plots together and crush graph density with rel_heights
    first_col = plot_grid(plot1, scatter, ncol = 1, rel_heights = c(1, 3))
    second_col = plot_grid(NULL, plot2, ncol = 1, rel_heights = c(1, 3))
    perfect = plot_grid(first_col, second_col, ncol = 2, rel_widths = c(3, 1),
                        axis = "lrbl", align = "hv")

    print(perfect)
}

And here is the call for plotting:

bivariate_plot(df = df_exp, var1 = "var1", var2 = "var2")

It is important to point out that this alignment problem is always present even by changing the data.

enter image description here

And this is what happen with my real data: enter image description here


Solution

  • There's so many bugs in your code that I don't quite know where to start. The code below fixes them, to the extent that I understand what the intended result is.

    g1 = c(rnorm(200, mean=350, sd=100), rnorm(200, mean=700, sd=100))
    g2 = c(rnorm(200, mean=350, sd=100), rnorm(200, mean=500, sd=100))
    df_exp = data.frame(var1=log2(g1 + 1) , var2=log2(g2 + 1))
    
    
    bivariate_plot <- function(df, var1, var2, density = T, box = F) {
      require(ggplot2)
      require(cowplot)
      scatter = ggplot(df, aes_string(var1, var2)) +
        geom_point(alpha=.8, color = "red")
    
      plot1 = ggplot(df, aes_string(var1)) + geom_density(alpha=.5, fill = "red") 
      plot1 = plot1 + ylab("G1 density")
    
      plot2 = ggplot(df, aes_string(var2)) + geom_density(alpha=.5, fill = "red") 
      plot2 = plot2 + ylab("G2 density")
    
      # Avoid displaying duplicated legend
      plot1 = plot1 + theme(legend.position="none")
      plot2 = plot2 + theme(legend.position="none")
    
      # Homogenize scale of shared axes
      min_exp = min(df[[var1]], df[[var2]]) - 0.01
      max_exp = max(df[[var1]], df[[var2]]) + 0.01
      scatter = scatter + ylim(min_exp, max_exp)
      scatter = scatter + xlim(min_exp, max_exp)
      plot1 = plot1 + xlim(min_exp, max_exp)
      plot2 = plot2 + xlim(min_exp, max_exp)
      plot1 = plot1 + ylim(0, 2)
      plot2 = plot2 + ylim(0, 2)
    
      # Flip axis of gg_dist_g2
      plot2 = plot2 + coord_flip()
    
      # Remove some duplicate axes
      plot1 = plot1 + theme(axis.title.x=element_blank(),
                            axis.text=element_blank(),
                            axis.line=element_blank(),
                            axis.ticks=element_blank())
    
      plot2 = plot2 + theme(axis.title.y=element_blank(),
                            axis.text=element_blank(),
                            axis.line=element_blank(),
                            axis.ticks=element_blank())
    
      # Modify margin c(top, right, bottom, left) to reduce the distance between plots
      #and align G1 density with the scatterplot
      plot1 = plot1 + theme(plot.margin = unit(c(0.5, 0, 0, 0.7), "cm"))
      scatter = scatter + theme(plot.margin = unit(c(0, 0, 0.5, 0.5), "cm"))
      plot2 = plot2 + theme(plot.margin = unit(c(0, 0.5, 0.5, 0), "cm"))
    
      # Combine all plots together and crush graph density with rel_heights
      perfect = plot_grid(plot1, NULL, scatter, plot2,
                          ncol = 2, rel_widths = c(3, 1), rel_heights = c(1, 3))
    
      print(perfect)
    }
    
    bivariate_plot(df = df_exp, var1 = "var1", var2 = "var2")
    

    enter image description here