
How to print variable importance of all the models in the leaderboard of h2o.automl in R


I am using the h2o.automl() function of the H2O package in R for regression.

Suppose I store the AutoML object in a variable named aml:

aml <- h2o.automl(x=x, y=y, training_frame = train_set,
              max_models = 20, seed = 1,
              keep_cross_validation_predictions = TRUE)

The leaderboard of h2o.automl() shows the top-performing models. I can print the importance of the predictors with the h2o.varimp() function and plot it with the h2o.varimp_plot() function, but only for the leader model (the best model found by AutoML):

h2o.varimp(aml@leader)
h2o.varimp_plot(aml@leader)

Is there any way to print the variable importance of the predictors for all the models in the leaderboard and plot a graph using the above two functions?


Solution

  • Stacked Ensembles (usually the leader model) do not yet support variable importance (JIRA here). However, the variable importance for the rest of the models can be retrieved by looping over the model IDs in the leaderboard. See the R code below.

    library(h2o)
    h2o.init()
    
    # Import a sample binary outcome training set into H2O
    train <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")
    
    # Identify predictors and response
    y <- "response"
    x <- setdiff(names(train), y)
    
    # For binary classification, response should be a factor
    train[,y] <- as.factor(train[,y])
    
    # Run AutoML for 10 models
    aml <- h2o.automl(x = x, y = y,
                      training_frame = train,
                      max_models = 10,
                      seed = 1)
    
    # View the AutoML Leaderboard
    lb <- aml@leaderboard
    print(lb, n = nrow(lb))
    
    # Get model ids for all models in the AutoML Leaderboard
    model_ids <- as.data.frame(lb$model_id)[,1]
    
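    # View variable importance for all the models (besides Stacked Ensembles)
    for (model_id in model_ids) {
      # Stacked Ensembles have no variable importance, so skip them
      if (grepl("^StackedEnsemble", model_id)) next
      print(model_id)
      m <- h2o.getModel(model_id)
      # Results are not auto-printed inside a loop, so call print() explicitly
      print(h2o.varimp(m))
      h2o.varimp_plot(m)
    }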
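If you would rather compare all models side by side than read one table per model, the per-model tables can be stacked into a single data frame. The sketch below is one possible approach, not part of the H2O API. `collect_varimp()` and `varimp_wide()` are hypothetical helper names. It assumes the table returned by `h2o.varimp()` has `variable` and `scaled_importance` columns (as tree-based models' tables do; check your H2O version), and it skips Stacked Ensembles by their `StackedEnsemble` model-ID prefix. The `get_varimp` argument exists only so the helper can be tried without a running cluster.

```r
# Hypothetical helper: gather variable importance from every
# non-Stacked-Ensemble model into one long data frame
# (one row per model x variable).
collect_varimp <- function(model_ids,
                           get_varimp = function(id) h2o::h2o.varimp(h2o::h2o.getModel(id))) {
  # Stacked Ensembles have no variable importance; skip them by ID prefix
  keep <- !grepl("^StackedEnsemble", model_ids)
  rows <- lapply(model_ids[keep], function(id) {
    vi <- as.data.frame(get_varimp(id))
    # A model that returns NULL (no importance) contributes no rows
    if (nrow(vi) == 0) return(NULL)
    data.frame(model_id = id,
               variable = as.character(vi$variable),
               scaled_importance = vi$scaled_importance,
               stringsAsFactors = FALSE)
  })
  do.call(rbind, rows)
}

# Hypothetical helper: reshape the long table into a variable x model
# matrix, with NA where a model reports nothing for a variable.
varimp_wide <- function(long) {
  tapply(long$scaled_importance, list(long$variable, long$model_id), sum)
}

# Usage with the model_ids vector from the code above (needs a live cluster):
# vi_long <- collect_varimp(model_ids)
# print(varimp_wide(vi_long))
```

The long format is convenient for filtering or plotting with other tools, and the wide matrix gives a quick at-a-glance comparison of how each model ranks the same predictors.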