Tags: r, ggplot2, gbm

Plot multiple runs of gbm.step() in one plot


I have run gbm.step() 100 times to account for the stochastic component of the analysis, which gives slightly different results for each BRT model, and I use these runs to estimate the range (minimum and maximum values) of the fitted functions. I would like to plot the results like this:

[Image: example plot of the fitted functions from 100 gbm runs]

Reproducible example:

library(dismo)  # provides gbm.step() and gbm.plot()
data(iris)
mod2 <- list()
for (i in 1:100) {
  mod2[[i]] <- gbm.step(data = iris, gbm.x = 3:4, gbm.y = 1,
                        family = "gaussian", tree.complexity = 4,
                        learning.rate = 0.01, bag.fraction = 0.5,
                        tolerance.method = "fixed")
}
gbm.plot(mod2[[1]], common.scale = FALSE, smooth = TRUE,
         write.title = FALSE, plot.layout = c(1, 2))

This plots only one of the 100 models. I want a single plot, like the image above, that shows all of them.

Is there a function that takes my 100 models and plots them like this? If not, what is the best approach in ggplot2?


Solution

  • We can try something like this:

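    library(dismo)  # provides gbm.step() and gbm.plot()
    data(iris)
    mod2 <- list()
    # 20 runs instead of 100, to keep the example quick
    for (i in 1:20) {
      mod2[[i]] <- gbm.step(data = iris,
                            gbm.x = 3:4, gbm.y = 1,
                            family = "gaussian", tree.complexity = 4,
                            learning.rate = 0.01, bag.fraction = 0.5,
                            tolerance.method = "fixed")
    }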
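
    Because the repeated runs exist to capture stochastic variation, it can help to make the ensemble itself reproducible. A minimal sketch, assuming the dismo package is installed: `fit_runs()` is a hypothetical helper (not part of dismo) that seeds each run with `set.seed()` and passes gbm.step()'s `plot.main = FALSE` and `verbose = FALSE` to cut down on per-run plots and console output.

    ```r
    # Hypothetical helper (not part of dismo): fit n_runs gbm.step() models,
    # seeding each run so the whole ensemble can be reproduced later.
    fit_runs <- function(data, n_runs = 20, base_seed = 1, ...) {
      runs <- lapply(seq_len(n_runs), function(i) {
        set.seed(base_seed + i)            # a different, but repeatable, seed per run
        dismo::gbm.step(data = data, ...,
                        plot.main = FALSE, # skip the hold-out deviance plot
                        verbose = FALSE)   # less console reporting
      })
      # defensive: keep only runs that actually returned a model object
      Filter(Negate(is.null), runs)
    }

    # Usage (requires dismo):
    # mod2 <- fit_runs(iris, n_runs = 20,
    #                  gbm.x = 3:4, gbm.y = 1, family = "gaussian",
    #                  tree.complexity = 4, learning.rate = 0.01,
    #                  bag.fraction = 0.5, tolerance.method = "fixed")
    ```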
    

    Next, we take the relevant parts of gbm.plot() to write a minimal function that returns the x and y values of the fitted function for a single predictor:

    # Extract the partial-dependence grid for one predictor of a gbm.step() model,
    # following the same steps gbm.plot() uses internally.
    getVar <- function(gbm.object, predictor_of_interest) {
      pred.names <- gbm.object$gbm.call$predictor.names
      k <- match(predictor_of_interest, pred.names)
      if (is.na(k)) stop("predictor not found: ", predictor_of_interest)
      # return.grid = TRUE returns the x values and fitted values instead of plotting
      response.matrix <- gbm::plot.gbm(gbm.object, k, return.grid = TRUE)
      # centre the fitted values on their mean, as gbm.plot() does
      data.frame(predictors = response.matrix[, 1],
                 responses = response.matrix[, 2] - mean(response.matrix[, 2]))
    }
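
    gbm.plot() shows every predictor at once, whereas getVar() handles one. A sketch of a hypothetical wrapper, `getAllVars()` (not from dismo or gbm), that calls an extractor (getVar() by default) for each name in `gbm.call$predictor.names` and adds a `variable` column so the curves can be faceted by predictor. The `extract` argument is an assumption added only so the function can be exercised without a fitted model.

    ```r
    # Hypothetical wrapper: partial-dependence data for every predictor of one model.
    # `extract` defaults to getVar(); the default is only looked up when it is used.
    getAllVars <- function(gbm.object, extract = getVar) {
      vars <- gbm.object$gbm.call$predictor.names
      pieces <- lapply(vars, function(v) {
        data.frame(extract(gbm.object, v), variable = v)
      })
      do.call(rbind, pieces)
    }

    # Usage with the models fitted above (assumes ggplot2 is loaded):
    # da_all <- do.call(rbind, lapply(seq_along(mod2), function(i)
    #   data.frame(getAllVars(mod2[[i]]), model = i)))
    # ggplot(da_all, aes(x = predictors, y = responses, group = model)) +
    #   geom_line(alpha = 0.4) +
    #   facet_wrap(~ variable, scales = "free_x") +
    #   theme_bw()
    ```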
    

    Then we iterate through the list of models and stack the results into one data frame, with a model column identifying each run:

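    library(ggplot2)
    # one data frame per model, tagged with the model index
    da <- lapply(seq_along(mod2), function(i) {
      data.frame(getVar(mod2[[i]], "Petal.Length"), model = i)
    })
    da <- do.call(rbind, da)  # stack into one long data frame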
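
    Summarising across runs point by point only works if every model's curve is evaluated at the same x values. That should hold here, since each run is fitted to the same data, but if the grids differ, one option is to resample every curve onto a shared grid with base R's approx(). A minimal sketch: `onCommonGrid()` is a hypothetical helper, and linear interpolation is an assumption of this sketch, not something gbm.plot() does.

    ```r
    # Hypothetical helper: resample each model's curve onto one shared x grid
    # using linear interpolation (approx), so curves can be summarised point-wise.
    onCommonGrid <- function(da, n = 100) {
      # use only the x range covered by every model, to avoid extrapolating
      lo <- max(tapply(da$predictors, da$model, min))
      hi <- min(tapply(da$predictors, da$model, max))
      grid <- seq(lo, hi, length.out = n)
      pieces <- lapply(split(da, da$model), function(d) {
        data.frame(predictors = grid,
                   responses = approx(d$predictors, d$responses, xout = grid)$y,
                   model = d$model[1])
      })
      out <- do.call(rbind, pieces)
      rownames(out) <- NULL
      out
    }

    # Usage: da_grid <- onCommonGrid(da), then plot da_grid exactly like da.
    ```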
    

    We can plot one line per model by mapping model to the group aesthetic:

    ggplot(da, aes(x = predictors, y = responses, group = model)) +
      geom_line(alpha = 0.4) + theme_bw()
    


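    Or, as you mentioned, show the minimum, maximum and mean with stat_summary(). This time we drop the group aesthetic so that the y values of all models are summarised at each x value: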

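    # fun.min / fun.max / fun replace the deprecated fun.ymin / fun.ymax / fun.y (ggplot2 >= 3.3.0)
    ggplot(da, aes(x = predictors, y = responses)) +
      stat_summary(geom = "ribbon", fun.min = min, fun.max = max, alpha = 0.3) +
      stat_summary(geom = "line", fun = mean) +
      theme_bw()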
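
    stat_summary() computes the envelope on the fly, so the numbers are never stored. If you also want the minimum, maximum and mean curves as a table (to inspect, export, or plot with geom_ribbon() directly), here is a minimal base-R sketch. `summariseRuns()` is a hypothetical helper that assumes the stacked `da` layout above (columns `predictors` and `responses`) with the same x grid in every model.

    ```r
    # Hypothetical helper: min, max and mean of the fitted curves at each x value,
    # for the stacked layout built above (columns predictors, responses).
    summariseRuns <- function(da) {
      xs  <- sort(unique(da$predictors))  # the shared x grid
      idx <- match(da$predictors, xs)     # which grid point each row belongs to
      data.frame(
        predictors = xs,
        ymin  = as.vector(tapply(da$responses, idx, min)),
        ymax  = as.vector(tapply(da$responses, idx, max)),
        ymean = as.vector(tapply(da$responses, idx, mean))
      )
    }

    # Usage (assumes ggplot2 is loaded and da was built as above):
    # env <- summariseRuns(da)
    # ggplot(env, aes(x = predictors)) +
    #   geom_ribbon(aes(ymin = ymin, ymax = ymax), alpha = 0.3) +
    #   geom_line(aes(y = ymean)) +
    #   theme_bw()
    ```

    Replacing min and max with a quantile-based function would give a percentile band instead of the full range.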
    
