Cross validation of the Partial Dependence Plots


As one of my robustness tests, I want to cross-validate a partial dependence plot, but I don't know where to start. My model is a regression tree, and I already have partial dependence plots based on the whole dataset. My questions are:

  1. If I split the dataset into 10 random samples and calculate the partial dependence of variable X on Y for each sample, how can I average the 10 results into a single plot? I cannot find a built-in function in Python or R that does this.

  2. The same task as above, but for a partial dependence plot of a two-way interaction, for example the joint effect of variables X1 and X2 on Y.

Thank you.


Solution

  • Further to my answer in the comments, if you wanted to look at the variance of the ICE curves, you could bootstrap them like this:

    library(pdp)
    library(randomForest)
    library(ICEbox)
    data(boston)
    X <- as.data.frame(model.matrix(cmedv ~ ., data=boston)[,-1])
    y <- model.response(model.frame(cmedv ~ ., data=boston))
    boston.rf <- randomForest(x=X, y=y)
    bice <- ice(boston.rf, X=X, predictor = "lstat") 
    
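    # Bootstrap the PD curve: resample the ICE curves (rows) with replacement
    # and average each resample over observations.
    res <- NULL
    for(i in 1:1000){
      inds <- sample(1:nrow(bice$ice_curves), 
                     nrow(bice$ice_curves), 
                     replace=TRUE)
      res <- rbind(res, colMeans(bice$ice_curves[inds, ]))
    }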
    
    out <- data.frame(
      fit = colMeans(bice$ice_curves), 
      lwr = apply(res, 2, quantile, .025),
      upr = apply(res, 2, quantile, .975), 
      x=bice$gridpts
    )
    
    library(ggplot2)
    ggplot(out, aes(x=x, y=fit, ymin=lwr, ymax=upr)) + 
      geom_ribbon(alpha=.25) + 
      geom_line() + 
      theme_bw() + 
      labs(x="lstat", y="Prediction")
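
Since the question also asks about Python, here is a rough NumPy sketch of the same idea: build the ICE matrix, resample its rows with replacement, and take percentiles of the resampled column means. The names `ice_matrix` and `bootstrap_pd_band` and the toy `predict` function are my own assumptions for illustration, not part of any library. Note that the model is held fixed here, so the band reflects only the variability of averaging over observations, not the variability from refitting the model.

```python
import numpy as np


def ice_matrix(predict, X, feature, grid):
    """ICE curves: one row per observation, one column per grid value."""
    curves = np.empty((len(X), len(grid)))
    for j, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value        # clamp the feature at the grid value
        curves[:, j] = predict(X_mod)    # predictions with everything else unchanged
    return curves


def bootstrap_pd_band(curves, n_boot=1000, level=0.95, seed=0):
    """Percentile band for the PD curve (the column means of the ICE matrix)."""
    rng = np.random.default_rng(seed)
    n = curves.shape[0]
    boot_means = np.empty((n_boot, curves.shape[1]))
    for b in range(n_boot):
        rows = rng.integers(0, n, size=n)            # resample observations
        boot_means[b] = curves[rows].mean(axis=0)    # PD curve of the resample
    alpha = (1 - level) / 2
    lwr, upr = np.quantile(boot_means, [alpha, 1 - alpha], axis=0)
    return curves.mean(axis=0), lwr, upr


# Toy data and a made-up prediction function standing in for a fitted model.
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 2))
predict = lambda X_new: np.sin(X_new[:, 0]) + 0.5 * X_new[:, 0] * X_new[:, 1]

grid = np.linspace(-2, 2, 25)
curves = ice_matrix(predict, X, feature=0, grid=grid)
fit, lwr, upr = bootstrap_pd_band(curves)
```

`fit`, `lwr` and `upr` play the same role as the columns of `out` in the R code, so they can be plotted against `grid` as a line with a ribbon.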
    


    Or, you could look at different quantiles of the ICE curves at each evaluation point.

    tmp <- t(apply(bice$ice_curves, 
                 2, 
                 quantile, c(0, .025, .05, .1, .25, .5, .75, .9, .95, .975, 1)))
    
    head(tmp)
    tmp <- as.data.frame(tmp)
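    # Name the upper quantiles so that l1/u1 = (0, 1), l2/u2 = (.025, .975), etc.,
    # which makes each ribbon below a symmetric band around the median.
    names(tmp) <- c("l1", "l2", "l3", "l4", "l5", 
                    "med", "u5", "u4", "u3", "u2", "u1")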
    
    tmp$x <- bice$gridpts
    
    ggplot(tmp, aes(x=x, y=med)) + 
      geom_ribbon(aes(ymin=l1, ymax=u1), alpha=.2) + 
      geom_ribbon(aes(ymin=l2, ymax=u2), alpha=.2) + 
      geom_ribbon(aes(ymin=l3, ymax=u3), alpha=.2) + 
      geom_ribbon(aes(ymin=l4, ymax=u4), alpha=.2) + 
      geom_ribbon(aes(ymin=l5, ymax=u5), alpha=.2) + 
      geom_line() + 
      theme_bw() + 
      labs(x="lstat", y="Prediction")
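
As for the fold-averaging asked about in the question itself, I'm not aware of a one-line built-in either, but it is short to do by hand. The key point is to evaluate every fold's partial dependence on the same grid, so the curves (or surfaces, for two variables) can be averaged pointwise. Below is a minimal Python sketch under some assumptions of mine: each model is fitted on the other nine folds (you could equally fit on each fold itself), `X` is a float array, and `fit` is any function that returns a prediction function. The helper names and the least-squares stand-in learner are hypothetical; with scikit-learn you could pass something like `lambda X, y: DecisionTreeRegressor().fit(X, y).predict` instead.

```python
import numpy as np


def partial_dependence(predict, X, features, grids):
    """Brute-force PD of `predict` on one or two columns of X.

    Returns an array with one axis per feature, shaped like the grids.
    """
    mesh = np.meshgrid(*grids, indexing="ij")
    out = np.empty(mesh[0].shape)
    for idx in np.ndindex(out.shape):
        X_mod = X.copy()
        for col, m in zip(features, mesh):
            X_mod[:, col] = m[idx]           # clamp the feature(s) at the grid point
        out[idx] = predict(X_mod).mean()     # average over the remaining features
    return out


def fold_averaged_pd(fit, X, y, features, grids, n_splits=10, seed=0):
    """Fit one model per fold and compute its PD on the same grid each time."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(X)), n_splits)
    curves = []
    for held_out in folds:
        train = np.setdiff1d(np.arange(len(X)), held_out)
        predict = fit(X[train], y[train])
        curves.append(partial_dependence(predict, X[train], features, grids))
    curves = np.stack(curves)                # shape: (n_splits, *grid_shape)
    return curves.mean(axis=0), curves.std(axis=0, ddof=1), curves


def fit_least_squares(X_train, y_train):
    """Stand-in learner (hypothetical); swap in your regression tree here."""
    A = np.column_stack([np.ones(len(X_train)), X_train])
    coef, *_ = np.linalg.lstsq(A, y_train, rcond=None)
    return lambda X_new: np.column_stack([np.ones(len(X_new)), X_new]) @ coef


# Toy data: y depends on columns 0 and 1 only.
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
y = 2.0 * X[:, 0] - X[:, 1] + rng.normal(scale=0.1, size=500)

grid0 = np.linspace(-2, 2, 21)
grid1 = np.linspace(-2, 2, 11)

# Question 1: one variable. Question 2: two-way interaction on a 2-D grid.
pd_mean, pd_sd, pd_all = fold_averaged_pd(fit_least_squares, X, y, [0], [grid0])
pd2_mean, pd2_sd, _ = fold_averaged_pd(fit_least_squares, X, y, [0, 1], [grid0, grid1])
```

For one variable, plot `pd_mean` against `grid0` (optionally with the individual rows of `pd_all` as faint lines); for two variables, `pd2_mean` is a surface you can show as a heatmap or contour plot. The fold-to-fold standard deviation is best read as a rough stability check rather than a formal confidence interval, since the training sets of different folds overlap.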
    
