I have run gbm.step 100 times to account for the stochastic component of the analysis: each BRT model gives slightly different results, so I want to estimate the range (minimum and maximum values) of the fitted functions.
I want to plot these results in a plot like this one:
Reproducible example:
library(dismo)  # provides gbm.step and gbm.plot
data(iris)
mod2 <- list()
for (i in 1:100) {
  mod2[[i]] <- gbm.step(data = iris, gbm.x = 3:4, gbm.y = 1,
                        family = "gaussian", tree.complexity = 4,
                        learning.rate = 0.01, bag.fraction = 0.5,
                        tolerance.method = "fixed")
}
gbm.plot(mod2[[1]], common.scale = FALSE, smooth = TRUE,
         write.title = FALSE, plot.layout = c(1, 2))
This is a plot of one of the 100 models. I want one like the image above.
Is there a function that takes my 100 models and plots them like this? If not, what is the best approach in ggplot?
We can try something like this:
library(dismo)  # provides gbm.step
data(iris)
mod2 <- list()
for (i in 1:20) {
  mod2[[i]] <- gbm.step(data = iris,
                        gbm.x = 3:4, gbm.y = 1,
                        family = "gaussian", tree.complexity = 4,
                        learning.rate = 0.01, bag.fraction = 0.5,
                        tolerance.method = "fixed")
}
Then we take the relevant parts out of gbm.plot to make a fairly primitive function that returns the x and y values of the fitted function for one predictor:
getVar <- function(gbm.object, predictor_of_interest) {
  gbm.call <- gbm.object$gbm.call
  pred.names <- gbm.call$predictor.names
  # Position of the requested predictor among the model's predictors
  k <- match(predictor_of_interest, pred.names)
  if (is.na(k)) stop("Unknown predictor: ", predictor_of_interest)
  # Grid of predictor values (column 1) and fitted function (column 2)
  response.matrix <- gbm::plot.gbm(gbm.object, k, return.grid = TRUE)
  # Centre the fitted function on zero
  data.frame(predictors = response.matrix[, 1],
             responses = response.matrix[, 2] - mean(response.matrix[, 2]))
}
Then we iterate through the list of models and collect the data, tagging each row with the index of the model it came from:
library(ggplot2)
da <- lapply(seq_along(mod2), function(i) {
  data.frame(getVar(mod2[[i]], "Petal.Length"), model = i)
})
da <- do.call(rbind, da)
We can plot all the lines, specifying group in the aes:
ggplot(da, aes(x = predictors, y = responses, group = model)) +
  geom_line(alpha = 0.4) + theme_bw()
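Or we can plot the min, max, and mean as you mentioned, using stat_summary without the group, so that the summary is taken across models at each x value:
# fun.min / fun.max / fun replace fun.ymin / fun.ymax / fun.y,
# which are deprecated since ggplot2 3.3.0
ggplot(da, aes(x = predictors, y = responses)) +
  stat_summary(geom = "ribbon", fun.min = min, fun.max = max, alpha = 0.3) +
  stat_summary(geom = "line", fun = mean) + theme_bw()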
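If you also want the range as numbers, not only as a picture, the same summary can be computed by hand. This is a minimal sketch, not part of dismo or gbm. It uses a small made-up data frame, da_toy, with the same three columns as da (the values are invented for illustration), and the names rng, ymin, ymean, and ymax are my own. It first checks that every model was evaluated on the same grid of predictor values, which is what a per-x summary relies on.

```r
# Toy stand-in for `da` (made-up numbers, same columns: predictors, responses, model)
da_toy <- data.frame(
  predictors = rep(c(1, 2, 3), times = 3),
  responses  = c(-1.0, 0.0, 1.0,  -0.8, 0.1, 1.2,  -1.1, -0.2, 0.9),
  model      = rep(1:3, each = 3)
)

# Each grid value should appear exactly once per model
n_models <- length(unique(da_toy$model))
stopifnot(all(table(da_toy$predictors) == n_models))

# One row per grid value: min, mean, and max of the fitted function across models
rng <- data.frame(
  predictors = sort(unique(da_toy$predictors)),
  ymin  = as.vector(tapply(da_toy$responses, da_toy$predictors, min)),
  ymean = as.vector(tapply(da_toy$responses, da_toy$predictors, mean)),
  ymax  = as.vector(tapply(da_toy$responses, da_toy$predictors, max))
)
print(rng)

# Same picture as the stat_summary version, drawn from the summary table
if (requireNamespace("ggplot2", quietly = TRUE)) {
  library(ggplot2)
  p <- ggplot(rng, aes(x = predictors)) +
    geom_ribbon(aes(ymin = ymin, ymax = ymax), alpha = 0.3) +
    geom_line(aes(y = ymean)) +
    theme_bw()
  print(p)
}
```

With the real data, replace da_toy with da. The rng table then holds the minimum and maximum of the fitted function at every grid value, so it can be written to a file or reported directly.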