Search code examples
r shiny xgboost kill-process cancel-button

How to terminate a currently long running XGBoost CV process in R shiny using a Button?


I would like to implement a cross-validation model in R Shiny using the xgboost model and the xgb.cv() function.

Since this process/function will take a couple of hours to complete, I would like to add a "Cancel" button, backed by a stop-process function, so that the user can terminate the process at any time.

Could you please advise me on how to proceed?

Server Code:

# Shiny server function: when the submit button is clicked, swap the
# Submit/Stop buttons, run the XGBoost grid-search cross-validation, and swap
# the buttons back afterwards.
#
# NOTE(review): `values` is not defined in this snippet; presumably it is a
# reactiveValues object created elsewhere -- confirm.
# NOTE(review): xgb_gs_cv_regression() is called synchronously inside the
# observer and its result is discarded here.
server <-  function(input, output, session) {

  observeEvent(input$ML_Submit_Button, {
    shinyjs::hide("ML_Submit_Button")
    shinyjs::show("ML_Stop_Button")

    # `finally` guarantees the buttons are restored even when the grid search
    # signals an error; previously a failure left the Stop button showing and
    # the Submit button hidden. The error itself still propagates.
    tryCatch(
      xgb_gs_cv_regression(
        xgb_train = values$xgb_train,
        subsample_choice = values$subsample_slider_seq,
        colsample_bytree_choice = values$colsample_bytree_slider_seq,
        max_depth_choice = values$max_depth_slider_seq,
        min_child_weight_choice = values$min_child_weight_slider_seq,
        eta_choice = values$eta_slider_seq,
        n_rounds_choice = values$n_rounds_slider_seq,
        n_fold_choice = values$n_fold_slider_seq
      ),
      finally = {
        shinyjs::hide("ML_Stop_Button")
        shinyjs::show("ML_Submit_Button")
      }
    )
  })

}

XGB CV Function Code:

#' Grid-search cross-validation for an XGBoost regression model
#'
#' Builds the full Cartesian grid of the supplied hyperparameter candidates,
#' runs `xgb.cv()` once per combination and collects the RMSE of the last
#' logged boosting round together with the hyperparameters that produced it.
#'
#' @param xgb_train Training data handed to `xgb.cv()` as `data`
#'   (presumably an `xgb.DMatrix` -- confirm against the caller).
#' @param subsample_choice Candidate values for `subsample`.
#' @param colsample_bytree_choice Candidate values for `colsample_bytree`.
#' @param max_depth_choice Candidate values for `max_depth`.
#' @param min_child_weight_choice Candidate values for `min_child_weight`.
#' @param eta_choice Candidate values for `eta`.
#' @param n_rounds_choice Candidate values for `nrounds`.
#' @param n_fold_choice Candidate values for `nfold`.
#' @return A data frame with one row per hyperparameter combination and the
#'   columns `TestRMSE`, `TrainRMSE`, `SubSampRate`, `ColSampRate`, `Depth`,
#'   `eta`, `currentMinChildWeight`, `nrounds` and `nfold`. An empty grid
#'   yields a zero-row data frame with the same columns.
xgb_gs_cv_regression <- function(xgb_train,
                                 subsample_choice,
                                 colsample_bytree_choice,
                                 max_depth_choice,
                                 min_child_weight_choice,
                                 eta_choice,
                                 n_rounds_choice,
                                 n_fold_choice) {
  search_grid <- expand.grid(
    subsample = subsample_choice,
    colsample_bytree = colsample_bytree_choice,
    max_depth = max_depth_choice,
    min_child_weight = min_child_weight_choice,
    eta = eta_choice,
    n_rounds = n_rounds_choice,
    n_fold = n_fold_choice
  )

  result_names <- c(
    "TestRMSE",
    "TrainRMSE",
    "SubSampRate",
    "ColSampRate",
    "Depth",
    "eta",
    "currentMinChildWeight",
    "nrounds",
    "nfold"
  )

  # Runs one cross-validation for grid row `i` and returns the nine result
  # values in the order of `result_names`.
  evaluate_combination <- function(i) {
    combo <- search_grid[i, , drop = FALSE]

    # Booster parameters go through the documented `params` list instead of
    # being passed loosely via `...`.
    cv_fit <- xgb.cv(
      params = list(
        objective = "reg:squarederror",
        booster = "gbtree",
        eval_metric = "rmse",
        max_depth = combo[["max_depth"]],
        eta = combo[["eta"]],
        subsample = combo[["subsample"]],
        colsample_bytree = combo[["colsample_bytree"]],
        min_child_weight = combo[["min_child_weight"]]
      ),
      data = xgb_train,
      nrounds = combo[["n_rounds"]],
      nfold = combo[["n_fold"]],
      showsd = TRUE,
      verbose = TRUE,
      print_every_n = 10,
      early_stopping_rounds = 10
    )

    eval_log <- cv_fit$evaluation_log

    # NOTE(review): the scores come from the LAST logged round; with early
    # stopping this may differ from the best round -- confirm this is intended.
    c(
      tail(eval_log$test_rmse_mean, 1),
      tail(eval_log$train_rmse_mean, 1),
      combo[["subsample"]],
      combo[["colsample_bytree"]],
      combo[["max_depth"]],
      combo[["eta"]],
      combo[["min_child_weight"]],
      combo[["n_rounds"]],
      combo[["n_fold"]]
    )
  }

  # vapply() iterates over row indices (no data.frame -> matrix coercion as
  # with apply()) and enforces that every combination yields exactly nine
  # numeric values; it also returns a well-formed 9 x 0 matrix for an empty
  # grid.
  scores <- vapply(
    seq_len(nrow(search_grid)),
    evaluate_combination,
    numeric(length(result_names))
  )

  results <- as.data.frame(t(scores))
  names(results) <- result_names
  results
}

Solution

  • You can realize the desired pattern via callr::r_bg(), which is based on processx::process().

    r_bg() runs R functions in a background R process, which can be canceled via its kill() method.

    Actually it doesn't matter which function you are running - so I simplified the example.

    Please check the following:

    library(shiny)
    library(callr)
    library(shinyjs)
    
    # Stand-in for an expensive computation: pauses for `x` seconds and then
    # returns a short status string reporting how long it slept.
    long_running_function <- function(x) {
      Sys.sleep(x)
      status_text <- sprintf("I slept %s seconds", x)
      status_text
    }
    
    # UI: one button to launch the background process and one to cancel it.
    # useShinyjs() is required for the enable()/disable() calls in the server.
    ui <- fluidPage(
      useShinyjs(),
      actionButton("runbgp", "Run bg process"),
      # Start disabled: there is nothing to cancel until a process is launched.
      disabled(actionButton("cancelbgp", "Cancel bg process"))
    )
    
    # Server: launches long_running_function() in a background R process via
    # callr::r_bg(), lets the user kill it, and polls once per second for
    # completion. The process handle lives in `rv$bg_process`; NULL means no
    # process is being tracked.
    server <- function(input, output, session) {
      rv <- reactiveValues(bg_process = NULL)

      # Start the background process and flip the buttons.
      observeEvent(input$runbgp, {
        disable("runbgp")
        enable("cancelbgp")
        rv$bg_process <- r_bg(
          long_running_function,
          args = list(5),
          stdout = "|",
          stderr = "2>&1"
        )
      })

      # Kill the background process on request.
      observeEvent(input$cancelbgp, {
        # Guard: without a tracked process, `rv$bg_process$get_pid()` would be
        # a call on NULL and raise an error.
        req(rv$bg_process)
        enable("runbgp")
        disable("cancelbgp")
        cat(paste("Killing process - PID:", rv$bg_process$get_pid(), "\n"))
        rv$bg_process$kill()
        # Drop the handle so the polling observer does not try to fetch a
        # result from the killed process.
        rv$bg_process <- NULL
      })

      # Poll once per second; when the process has finished, print its result
      # and reset the UI.
      observe({
        invalidateLater(1000)
        req(rv$bg_process)
        if (rv$bg_process$poll_io(0)[["process"]] == "ready") {
          enable("runbgp")
          disable("cancelbgp")
          # get_result() signals an error when the background function failed
          # or the process died; report it instead of letting it escape the
          # observer.
          tryCatch(
            print(rv$bg_process$get_result()),
            error = function(e) {
              message("Background process failed: ", conditionMessage(e))
            }
          )
          rv$bg_process <- NULL
        }
      })
    }
    
    # Build the Shiny app object from the UI and server defined above.
    shinyApp(ui, server)