Search code examples
rshinyplotlylinear-regressionggplotly

How to toggle traces with a button?


My R Shiny app currently generates dynamic plots and organizes them into categories. When input$regression is true, all output$continuous_plots are rerendered to include a regression line. These plots update dynamically based on changes in input$group_var, input$regression (a boolean), and input$group_var_values. Instead of rerendering the plots each time to include a regression line or revert to the original scatter plot, I'd like to dynamically add traces that represent the regression line when input$regression is true. When input$regression is false, the plots should reset to the original scatter plot. Is this possible with my code?

output$binary_plots <- renderUI({
      binary_x_vars <- Filter(function(x) get_variable_type(global_dat[[x]]) == "binary", input$x_sel)  # Filter binary variables
      plot_output_list <- lapply(binary_x_vars, function(x_var) {
        plotname <- paste("plot", x_var, sep = "_")
        plot_output <- plotlyOutput(plotname, height = '300px',width = '100%')  # Create plot output for each continuous variable
        div(style = "margin-bottom: 10px;", plot_output)
      })
      
      do.call(tagList, plot_output_list)  # Combine plot outputs into a tag list
    })
    
    output$continuous_plots <-  renderUI({
      continuous_x_vars <- Filter(function(x) get_variable_type(global_dat[[x]]) == "continuous", input$x_sel)  # Filter continuous variables
      plot_output_list <- lapply(continuous_x_vars, function(x_var) {
        plotname <- paste("plot", x_var, sep = "_")
        plot_output <- plotlyOutput(plotname, height = '300px',width = '100%')  # Create plot output for each continuous variable
        div(style = "margin-bottom: 20px;", plot_output)
      })
      
      do.call(tagList, plot_output_list)  # Combine plot outputs into a tag list
    })
    
    output$string_plots <- renderUI({
      string_x_vars <- Filter(function(x) get_variable_type(global_dat[[x]]) == "string", input$x_sel)  # Filter string variables
      plot_output_list <- lapply(string_x_vars, function(x_var) {
        plotname <- paste("plot", x_var, sep = "_")
        plot_output <- plotlyOutput(plotname, height = '300px',width = '100%')  # Create plot output for each continuous variable
        div(style = "margin-bottom: 5px;", plot_output)
      })
      
      do.call(tagList, plot_output_list)  # Combine plot outputs into a tag list
    })
    
    
    
    
    observe({
      req(input$y_sel, input$x_sel)  # Require selection of y and x variables

      lapply(input$x_sel, function(x_var) {
        output[[paste("plot", x_var, sep = "_")]] <- renderPlotly({
          filtered_dat <- global_dat

          # Apply filter based on selected group values
          if (!is.null(input$group_var_values) && length(input$group_var_values) > 0) {
            filtered_dat <- filtered_dat %>% filter(filtered_dat[[input$group_var]] %in% as.list(input$group_var_values))
          }

          # Define plot name for this iteration
          plot_name <- glue::glue('{input$y_sel}_vs_{x_var}')

          # Reset input values so the donwload csv names are unique to every input$y_sel and input$x_sel combination
          isolate({
            updateSelectInput(session, "y_sel", selected = NULL)
            updateSelectInput(session, "x_sel", selected = NULL)
          })

          # Generate plot
          p <- if (is.factor(filtered_dat[[x_var]]) || is.factor(filtered_dat[[input$y_sel]])) {
            if (input$group_var == 'None selected') {
              ggplot(filtered_dat, aes_string(x = x_var, y = input$y_sel)) +
                geom_boxplot() +
                ggtitle(paste("Boxplot of", x_var, "vs", input$y_sel)) +
                theme_bw()
            } else {
              ggplot(filtered_dat, aes_string(x = x_var, y = input$y_sel, color = input$group_var,customdata = 'row_id')) +
                geom_boxplot() +
                ggtitle(paste("Boxplot of", x_var, "vs", input$y_sel, "with Group Coloring")) +
                theme_bw()
            }
          } else {
            if (input$group_var == 'None selected') {
              ggplot(filtered_dat, aes_string(x = x_var, y = input$y_sel)) +
                geom_point() +
                {
                  if (input$regression)
                    stat_smooth(
                      method = "lm",se = F,
                      linetype = "dashed",
                      color = "red"
                    )
                } +
                ggtitle(paste("Scatter Plot of", x_var, "vs", input$y_sel)) +
                theme_bw()
            } else {
              ggplot(filtered_dat, aes_string(x = x_var, y = input$y_sel, color = as.character(input$group_var),customdata = 'row_id')) +
                geom_point(alpha = .5) +
                {
                  if (input$regression)
                    stat_smooth(method = "lm", se = F,linetype = 'dashed')
                } +
                ggtitle(paste("Scatter Plot of", x_var, "vs", input$y_sel, "with Group Coloring")) +
                theme_bw()
            }
          }

          # Convert ggplot to plotly
          p <- ggplotly(p, source = "plot1") %>%  layout(clickmode = "event+select", dragmode = 'select')

          # Configure the plot with the download button
          p <- config(
            p,
            scrollZoom = TRUE,
            modeBarButtonsToAdd = list(
              list(button_fullscreen(), button_download(data = p[["x"]][["visdat"]][[p[["x"]][["cur_data"]]]](), plot_name = plot_name))
            ),
            modeBarButtonsToRemove = c("toImage", "hoverClosest", "hoverCompare"),
            displaylogo = FALSE
          )

          # Return the plot
          p %>% toWebGL()
        })
      })
    })

This is the closest example I have gotten to that will generate regresssions depending on the color group, but I don't know how to delete the traces and just keep the scatter plot if the checkmark input value is false:

library(shiny)
library(plotly)


# Generate 100,000 observations from 2 correlated random variables
s <- matrix(c(1, 0.5, 0.5, 1), 2, 2)
d <- MASS::mvrnorm(300, mu = c(0, 0), Sigma = s)
d <- setNames(as.data.frame(d), c("x", "y"))

# Introduce a grouping variable
set.seed(123)  # for reproducibility
d$group <- sample(letters[1:3], nrow(d), replace = TRUE)

# fit separate linear models for each group
models <- lapply(unique(d$group), function(g) {
  lm(y ~ x, data = subset(d, group == g))
})

# generate y predictions over a grid of 10 x values for each group
dpred <- lapply(models, function(model) {
  data.frame(
    x = seq(min(d$x), max(d$x), length.out = 10),
    yhat = predict(model, newdata = data.frame(x = seq(min(d$x), max(d$x), length.out = 10)))
  )
})

# Define colors for each group
group_colors <- c("red", "blue", "green")

ui <- fluidPage(
  plotlyOutput("scatterplot"),
  checkboxInput(
    "smooth", 
    label = "Overlay fitted lines?", 
    value = FALSE
  )
)

server <- function(input, output, session) {
  
  added_traces <- list()  # Initialize list to store added traces indices
  
  output$scatterplot <- renderPlotly({
    p <- plot_ly()  # Initialize plot object
    
    # Add markers for each group
    for (i in seq_along(models)) {
      group_data <- subset(d, group == unique(d$group)[i])
      p <- p %>% add_markers(
        data = group_data,
        x = ~x, y = ~y,
        color = I(group_colors[i]),
        alpha = 0.5
      )
    }
    
    p %>% toWebGL()
  })
  
  observeEvent(input$smooth, {
    if (input$smooth) {
      # Add lines for each group's regression line
      for (i in seq_along(dpred)) {
        trace <- plotlyProxy("scatterplot", session) %>%
          plotlyProxyInvoke(
            "addTraces",
            list(
              x = dpred[[i]]$x,
              y = dpred[[i]]$yhat,
              type = "scattergl",
              mode = "lines",
              line = list(color = group_colors[i])
            )
          )
        added_traces <- c(added_traces, trace)  # Store the index of added trace
      }
    } else {
      # Remove all traces if checkbox is unchecked
      plotlyProxy("scatterplot", session) %>%
        plotlyProxyInvoke("deleteTraces",1)
      
    }
  }, ignoreInit = TRUE)
  
}

shinyApp(ui, server)

Solution

  • Below is an example where the button can be used for switching the traces on and off.

    enter image description here

    library(shiny)
    library(plotly)
    library(htmlwidgets)
    
    js <- "function(el, x, data){
      var id = el.getAttribute('id');
      $(document).on('shiny:inputchanged', function(event) {
        if (event.name === 'smooth') {
          var out = [];
          d3.select('#' + id + ' g.legend').selectAll('.traces').each(function(){
            var trace = d3.select(this)._groups[0][0].__data__[0].trace;
            out.push([name=trace.name, index=trace.index, mode=trace.mode]);
          });
          Shiny.setInputValue('TraceInfo', out);
        }
      });
    }"
    
    
    # Generate 100,000 observations from 2 correlated random variables
    s <- matrix(c(1, 0.5, 0.5, 1), 2, 2)
    d <- MASS::mvrnorm(300, mu = c(0, 0), Sigma = s)
    d <- setNames(as.data.frame(d), c("x", "y"))
    
    # Introduce a grouping variable
    set.seed(123)  # for reproducibility
    d$group <- sample(letters[1:3], nrow(d), replace = TRUE)
    
    # fit separate linear models for each group
    models <- lapply(unique(d$group), function(g) {
      lm(y ~ x, data = subset(d, group == g))
    })
    
    # generate y predictions over a grid of 10 x values for each group
    dpred <- lapply(models, function(model) {
      data.frame(
        x = seq(min(d$x), max(d$x), length.out = 10),
        yhat = predict(model, newdata = data.frame(x = seq(min(d$x), max(d$x), length.out = 10)))
      )
    })
    
    # Define colors for each group
    group_colors <- c("red", "blue", "green")
    
    ui <- fluidPage(
      plotlyOutput("scatterplot"),
      checkboxInput(
        "smooth", 
        label = "Overlay fitted lines?", 
        value = FALSE
      ),
      tags$head(tags$script(src = "https://cdnjs.cloudflare.com/ajax/libs/d3/7.3.0/d3.min.js"))
    )
    
    server <- function(input, output, session) {
      
      added_traces <- list()  # Initialize list to store added traces indices
      
      output$scatterplot <- renderPlotly({
        p <- plot_ly()  # Initialize plot object
        
        # Add markers for each group
        for (i in seq_along(models)) {
          group_data <- subset(d, group == unique(d$group)[i])
          p <- p %>% add_markers(
            data = group_data,
            x = ~x, y = ~y,
            color = I(group_colors[i]),
            alpha = 0.5
          )
        }
        
        p %>% toWebGL() %>% onRender(js) 
      })
      
      observeEvent(input$smooth, {
        if (input$smooth) {
          # Add lines for each group's regression line
          for (i in seq_along(dpred)) {
            req(input$TraceInfo)
            trace <- plotlyProxy("scatterplot", session) %>%
              plotlyProxyInvoke(
                "addTraces",
                list(
                  x = dpred[[i]]$x,
                  y = dpred[[i]]$yhat,
                  type = "scattergl",
                  mode = "lines",
                  line = list(color = group_colors[i]),
                  name = input$TraceInfo
                )
              )
            added_traces <- c(added_traces, trace)  # Store the index of added trace
          }
        } else {
          # Remove all traces if checkbox is unchecked
          req(input$TraceInfo)
          traces <- matrix(input$TraceInfo, ncol = 3, byrow = TRUE)
          indices <- as.integer(traces[traces[, 3] == "lines", 2])
          plotlyProxy("scatterplot", session) %>%
            plotlyProxyInvoke("deleteTraces", indices)
          
        }
      }, ignoreInit = TRUE)
      
    }
    
    shinyApp(ui, server)