I'm using the caret package in R to train a model with the 'xgbTree' method.
When I plot the trained model (see the picture below), the plot shows the tuning parameter eta = 0.2. That is not what I want: the expand.grid I defined before training also includes eta = 0.1, and that is the value in the best tune. How can I make the plot function show the eta = 0.1 scenario instead of eta = 0.2? Thank you.
set.seed(100) # For reproducibility
xgb_trcontrol = trainControl(
method = "cv",
#repeats = 2,
number = 10,
#search = 'random',
allowParallel = TRUE,
verboseIter = FALSE,
returnData = TRUE
)
xgbGrid <- expand.grid(nrounds = c(100,200,1000), # this is n_estimators in the python code above
max_depth = c(6:8),
colsample_bytree = c(0.6,0.7),
## The values below are default values in the sklearn-api.
eta = c(0.1,0.2),
gamma=0,
min_child_weight = c(5:8),
subsample = c(0.6,0.7,0.8,0.9)
)
set.seed(0)
xgb_model8 = train(
x, y_train,
trControl = xgb_trcontrol,
tuneGrid = xgbGrid,
method = "xgbTree"
)
What happens is that plot() draws over all the values in your grid, one page after another on the same plotting device, so you only see the last page to appear, which is the one for eta = 0.2. For example:
library(caret)
xgb_trcontrol = trainControl(method = "cv", number = 3, returnData = TRUE)
xgbGrid <- expand.grid(nrounds = c(100,200,1000),
max_depth = c(6:8),
colsample_bytree = c(0.6,0.7),
eta = c(0.1,0.2),
gamma=0,
min_child_weight = c(5:8),
subsample = c(0.6,0.7,0.8,0.9)
)
set.seed(0)
x = mtcars[,-1]
y_train = mtcars[,1]
xgb_model8 = train(
x, y_train,
trControl = xgb_trcontrol,
tuneGrid = xgbGrid,
method = "xgbTree"
)
You can save all of the pages to a single PDF, so that none of them is overwritten, like this:
pdf("plots.pdf")
plot(xgb_model8,metric="RMSE")
dev.off()
Or, if you want to plot a specific value of one parameter, for example eta = 0.1, you would also need to fix colsample_bytree, otherwise there are too many parameters to show in one plot:
library(ggplot2)
ggplot(subset(xgb_model8$results, eta == 0.1 & colsample_bytree == 0.6),
       aes(x = min_child_weight, y = RMSE,
           group = factor(subsample), col = factor(subsample))) +
  geom_line() + geom_point() + facet_grid(nrounds ~ max_depth)
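If you would rather not hard-code eta and colsample_bytree, one option is to take them from the best tune that caret stores in the model. The sketch below is only an illustration: filter_to_best is a hypothetical helper, not part of caret, and demo_results and demo_best are synthetic stand-ins (with made-up RMSE values) for xgb_model8$results and xgb_model8$bestTune, so that it runs without training anything.

```r
# Hypothetical helper (not part of caret): keep only the rows of a results
# table whose values in the `fixed` columns match those of the best tune.
filter_to_best <- function(results, best_tune,
                           fixed = c("eta", "colsample_bytree")) {
  keep <- rep(TRUE, nrow(results))
  for (param in fixed) {
    keep <- keep & results[[param]] == best_tune[[param]]
  }
  results[keep, , drop = FALSE]
}

# Synthetic stand-in for xgb_model8$results; the RMSE values are made up
# purely so the example runs on its own.
demo_results <- expand.grid(eta = c(0.1, 0.2),
                            colsample_bytree = c(0.6, 0.7),
                            max_depth = 6:8)
demo_results$RMSE <- seq_len(nrow(demo_results))

# Synthetic stand-in for xgb_model8$bestTune (a one-row data frame).
demo_best <- demo_results[1, ]

demo_subset <- filter_to_best(demo_results, demo_best)
```

With the real model, the same call would be filter_to_best(xgb_model8$results, xgb_model8$bestTune), and the result can be passed to ggplot() in place of the subset() call above.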