In one of my robustness tests, I want to cross-validate a partial dependence plot, but I don't know where to start. My model is a regression tree, and I have partial dependence plots based on the whole dataset. My questions are:
1. If I randomly divide the dataset into 10 samples and calculate the partial dependence of Y on variable X for each sample, how can I average the results of the 10 samples to produce a single plot? I cannot find a built-in function in Python or R that does this.
2. Same task as above, but for a two-way interaction: how would I do this for the partial dependence plot of Y on the pair of variables X1 and X2?
Thank you.
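Both questions come down to computing each sample's partial dependence on one shared grid and then averaging pointwise, which is easy to write by hand. The sketch below is a minimal illustration under stated assumptions: the data are simulated (hypothetical `x1`, `x2`, `x3`, `y`), the tree is an `rpart` regression tree refit on the other 9 folds for each of 10 random folds, and `pd_grid`, `one_way`, `two_way`, `pd1` and `pd2` are hypothetical names, not functions from any package.

```r
library(rpart)  # regression trees; ships with standard R installations

# Hypothetical helper: for each row of `grid`, overwrite the gridded
# variable(s) in every row of `data`, predict, and average the predictions.
pd_grid <- function(fit, data, grid) {
  vapply(seq_len(nrow(grid)), function(j) {
    d <- data
    for (v in names(grid)) d[[v]] <- grid[[v]][j]
    mean(predict(fit, newdata = d))
  }, numeric(1))
}

# Simulated stand-in data (hypothetical): y depends on x1 and on x1:x2
set.seed(42)
n <- 500
sim <- data.frame(x1 = runif(n), x2 = runif(n), x3 = rnorm(n))
sim$y <- 3 * sim$x1 + 2 * sim$x1 * sim$x2 + rnorm(n, sd = 0.3)

# Assign every row to one of k = 10 random folds
k <- 10
folds <- sample(rep(seq_len(k), length.out = n))

# Grids shared by all folds, so the curves line up point by point
grid1 <- data.frame(x1 = seq(0.05, 0.95, length.out = 20))
grid2 <- expand.grid(x1 = seq(0.1, 0.9, length.out = 9),
                     x2 = seq(0.1, 0.9, length.out = 9))

one_way <- matrix(NA_real_, nrow(grid1), k)  # rows = grid points, cols = folds
two_way <- matrix(NA_real_, nrow(grid2), k)
for (i in seq_len(k)) {
  train <- sim[folds != i, ]  # assumption: refit on the other k - 1 folds
  fit <- rpart(y ~ x1 + x2 + x3, data = train)
  one_way[, i] <- pd_grid(fit, train, grid1)
  two_way[, i] <- pd_grid(fit, train, grid2)
}

# Average across folds; the spread across folds is a rough stability measure
pd1 <- data.frame(grid1, yhat = rowMeans(one_way), sd = apply(one_way, 1, sd))
pd2 <- data.frame(grid2, yhat = rowMeans(two_way))
```

With ggplot2, `ggplot(pd1, aes(x1, yhat)) + geom_line()` draws the averaged one-way curve and `ggplot(pd2, aes(x1, x2, fill = yhat)) + geom_tile()` the averaged two-way surface. The shared grid is the key design choice: if each fold picked its own grid from its own data, the x-values would differ between folds and the curves could not be averaged point by point.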
Further to my answer in the comments, if you want to gauge the variability of the partial dependence curve (the average of the ICE curves), you could bootstrap the ICE curves like this:
library(pdp)
library(randomForest)
library(ICEbox)
data(boston)  # Boston housing data shipped with the pdp package
X <- as.data.frame(model.matrix(cmedv ~ ., data=boston)[,-1])
y <- model.response(model.frame(cmedv ~ ., data=boston))
boston.rf <- randomForest(x=X, y=y)
# ICE curves for lstat: one row per observation, one column per grid point
bice <- ice(boston.rf, X=X, predictor = "lstat")
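res <- NULL
for(i in 1:1000){
  # resample observations (rows of the ICE matrix) with replacement
  inds <- sample(1:nrow(bice$ice_curves),
                 nrow(bice$ice_curves),
                 replace=TRUE)
  # the mean of the resampled ICE curves is one bootstrap PD curve
  res <- rbind(res, colMeans(bice$ice_curves[inds, ]))
}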
out <- data.frame(
  fit = colMeans(bice$ice_curves),      # partial dependence: the mean ICE curve
  lwr = apply(res, 2, quantile, .025),  # 95% percentile bootstrap band
  upr = apply(res, 2, quantile, .975),
  x = bice$gridpts
)
library(ggplot2)
ggplot(out, aes(x=x, y=fit, ymin=lwr, ymax=upr)) +
  geom_ribbon(alpha=.25) +
  geom_line() +
  theme_bw() +
  labs(x="lstat", y="Prediction")
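The resampling step only touches the ICE matrix (one row per observation, one column per grid point), so its logic can be checked without fitting a forest. This sketch uses a simulated matrix as a hypothetical stand-in for `bice$ice_curves` and builds the bootstrap replicates with `replicate()` instead of growing `res` with `rbind()` inside a loop; the percentile band is the same idea as above.

```r
# Hypothetical stand-in for bice$ice_curves: 200 observations x 25 grid points
set.seed(1)
gridpts <- seq(0, 1, length.out = 25)
ice_curves <- outer(rnorm(200, sd = 0.5), rep(1, 25)) +   # per-observation offset
  outer(rep(1, 200), sin(2 * pi * gridpts)) +             # shared curve shape
  matrix(rnorm(200 * 25, sd = 0.1), 200, 25)              # noise

B <- 1000
n <- nrow(ice_curves)
# Each replicate resamples rows with replacement and averages them;
# replicate() returns 25 x B, so transpose to B x 25
boot_means <- t(replicate(B, colMeans(
  ice_curves[sample.int(n, n, replace = TRUE), , drop = FALSE])))

band <- data.frame(
  x = gridpts,
  fit = colMeans(ice_curves),                   # mean ICE curve = PD curve
  lwr = apply(boot_means, 2, quantile, 0.025),  # percentile interval bounds
  upr = apply(boot_means, 2, quantile, 0.975)
)
```

This resamples observations for a fixed fitted model; it does not refit the model, so the band reflects only which observations are averaged, not how much the model itself would change on new data.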
Alternatively, you could look at the quantiles of the ICE curves at each grid point.
# one row per grid point, one column per quantile of the ICE curves
tmp <- t(apply(bice$ice_curves,
               2,
               quantile, c(0, .025, .05, .1, .25, .5, .75, .9, .95, .975, 1)))
head(tmp)
tmp <- as.data.frame(tmp)
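# quantile columns are ordered 0, .025, ..., .975, 1, so pair them
# outside-in: l1/u1 = 0/1, l2/u2 = .025/.975, ..., l5/u5 = .25/.75
names(tmp) <- c("l1", "l2", "l3", "l4", "l5",
                "med", "u5", "u4", "u3", "u2", "u1")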
tmp$x <- bice$gridpts
ggplot(tmp, aes(x=x, y=med)) +
  geom_ribbon(aes(ymin=l1, ymax=u1), alpha=.2) +
  geom_ribbon(aes(ymin=l2, ymax=u2), alpha=.2) +
  geom_ribbon(aes(ymin=l3, ymax=u3), alpha=.2) +
  geom_ribbon(aes(ymin=l4, ymax=u4), alpha=.2) +
  geom_ribbon(aes(ymin=l5, ymax=u5), alpha=.2) +
  geom_line() +
  theme_bw() +
  labs(x="lstat", y="Prediction")
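Each ribbon should span a symmetric pair of quantile levels, p and 1 - p, so that the bands nest around the median. A sketch that builds those pairs from a list of lower tail probabilities, so no column renaming is needed, again using a simulated matrix as a hypothetical stand-in for `bice$ice_curves` (`q_at`, `bands` and `med` are hypothetical names):

```r
# Hypothetical stand-in for bice$ice_curves: 150 observations x 25 grid points
set.seed(2)
gridpts <- seq(0, 1, length.out = 25)
ice_curves <- outer(rnorm(150), rep(1, 25)) +
  outer(rep(1, 150), 2 * gridpts) +
  matrix(rnorm(150 * 25, sd = 0.2), 150, 25)

# Lower tail probabilities; each is paired with its mirror 1 - p
lower_p <- c(0, 0.025, 0.05, 0.1, 0.25)

# Pointwise quantile of the ICE curves at every grid point
q_at <- function(p) apply(ice_curves, 2, quantile, p, names = FALSE)

# Long format: one block of rows per band, with its coverage level 1 - 2p
bands <- do.call(rbind, lapply(lower_p, function(p) {
  data.frame(x = gridpts, p = p, level = 1 - 2 * p,
             lwr = q_at(p), upr = q_at(1 - p))
}))
med <- data.frame(x = gridpts, med = q_at(0.5))
```

With ggplot2, `geom_ribbon(data = bands, aes(x = x, ymin = lwr, ymax = upr, group = level), alpha = .2)` plus `geom_line(data = med, aes(x = x, y = med))` draws the same kind of nested-band plot.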