
Change tuning parameters shown in the plot created by Caret in R


I'm using the caret package in R to train a model with the 'xgbTree' method.

After plotting the trained model (see the picture below), the plot shows the tuning parameter eta = 0.2. That is not what I want: I also defined eta = 0.1 in expand.grid before training the model, and eta = 0.1 is the best tune. How can I make the plot function show the eta = 0.1 scenario instead of eta = 0.2? Thank you.

[Image: plot of the trained model, showing eta = 0.2]

library(caret)

set.seed(100)  # For reproducibility

# 10-fold cross-validation
xgb_trcontrol = trainControl(
  method = "cv",
  #repeats = 2,
  number = 10,
  #search = 'random',
  allowParallel = TRUE,
  verboseIter = FALSE,
  returnData = TRUE
)


xgbGrid <- expand.grid(nrounds = c(100, 200, 1000),  # n_estimators in the sklearn API
                       max_depth = 6:8,
                       colsample_bytree = c(0.6, 0.7),
                       ## The values below are default values in the sklearn-api.
                       eta = c(0.1, 0.2),
                       gamma = 0,
                       min_child_weight = 5:8,
                       subsample = c(0.6, 0.7, 0.8, 0.9)
)


set.seed(0) 
xgb_model8 = train(
  x, y_train,
  trControl = xgb_trcontrol,
  tuneGrid = xgbGrid,
  method = "xgbTree"
)

Solution

  • plot() draws over all the values in your grid, one page after another. On screen each page replaces the previous one, so the one you end up looking at is the last to appear, which is eta = 0.2. For example:

    library(caret)

    # 3-fold CV keeps the example quick
    xgb_trcontrol = trainControl(method = "cv", number = 3, returnData = TRUE)

    xgbGrid <- expand.grid(nrounds = c(100, 200, 1000),
                           max_depth = 6:8,
                           colsample_bytree = c(0.6, 0.7),
                           eta = c(0.1, 0.2),
                           gamma = 0,
                           min_child_weight = 5:8,
                           subsample = c(0.6, 0.7, 0.8, 0.9)
    )

    set.seed(0)

    # mtcars stands in for the original data: mpg (column 1) is the response
    x = mtcars[, -1]
    y_train = mtcars[, 1]

    xgb_model8 = train(
      x, y_train,
      trControl = xgb_trcontrol,
      tuneGrid = xgbGrid,
      method = "xgbTree"
    )
    

    To keep every page instead of only the last one, write the plots to a PDF file like this:

    pdf("plots.pdf")
    plot(xgb_model8,metric="RMSE")
    dev.off()
    

    Or, if you want to plot a specific value of one parameter, for example eta = 0.1, you also need to fix colsample_bytree; otherwise there are too many parameters to show in one figure:

    library(ggplot2)

    # Keep one value each of eta and colsample_bytree, then plot the rest
    ggplot(subset(xgb_model8$results, eta == 0.1 & colsample_bytree == 0.6),
           aes(x = min_child_weight, y = RMSE,
               group = factor(subsample), col = factor(subsample))) +
      geom_line() +
      geom_point() +
      facet_grid(nrounds ~ max_depth)
    

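    [Image: RMSE against min_child_weight for eta = 0.1 and colsample_bytree = 0.6, one line per subsample value, faceted by nrounds and max_depth]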
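
Since the question asks for the plot at the best tune, one option is to read the fixed values from the fitted model rather than typing 0.1 and 0.6 by hand. The sketch below assumes the usual `results` and `bestTune` elements of a caret train object. `subset_to_best` is a hypothetical helper written for this illustration, not a caret function, and the toy table only stands in for `xgb_model8$results` so the block runs on its own:

```r
# Sketch: keep only the rows of a caret results table whose values for the
# chosen tuning parameters match the best tune.
# `subset_to_best` is a hypothetical helper, not part of caret.
subset_to_best <- function(results, best_tune,
                           fix = c("eta", "colsample_bytree")) {
  keep <- rep(TRUE, nrow(results))
  for (p in fix) {
    # Compare with a tolerance so floating-point noise cannot drop rows
    keep <- keep & abs(results[[p]] - best_tune[[p]]) < 1e-8
  }
  results[keep, , drop = FALSE]
}

# Toy stand-in for xgb_model8$results so the block runs without training.
# The RMSE column is a placeholder, not real cross-validation output.
toy_results <- expand.grid(eta = c(0.1, 0.2),
                           colsample_bytree = c(0.6, 0.7),
                           min_child_weight = 5:8)
toy_results$RMSE <- seq_len(nrow(toy_results))

# Pretend the first row is the best tune (an assumption for illustration)
toy_best <- toy_results[1, c("eta", "colsample_bytree", "min_child_weight")]

best_rows <- subset_to_best(toy_results, toy_best)
print(best_rows)

# With the fitted model from the answer, the call would be:
# best_rows <- subset_to_best(xgb_model8$results, xgb_model8$bestTune)
```

The resulting `best_rows` can be passed to the ggplot call above in place of the `subset(...)` expression, so the figure follows whatever the best tune turns out to be.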