
How to create a partial dependence plot using H2O in Spark?


I am trying to create a partial dependence plot with the following code:

rf_pdp = rf_model.partial_plot(data=htest, cols=['var1', 'var2', 'var3'], plot=True)
rf_pdp

It runs without errors and generates a table with mean_response, stddev_response and std_error_mean_response for each variable, but no plot is shown. Is that because I am running the code in a Spark environment?
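For context, the table that partial_plot returns is the partial dependence itself: for each grid value of a column, that column is set to the value in every row, the model is scored, and the predictions are summarised. Below is a minimal pure-Python sketch of that idea, assuming a generic predict callable and rows stored as dicts (both hypothetical, not the H2O API). It is not H2O's implementation; in particular, this sketch uses the sample standard deviation, and whether H2O uses the sample or the population version is not stated here.

```python
import math
import statistics
from typing import Callable, Dict, List, Sequence

Row = Dict[str, float]


def partial_dependence(
    predict: Callable[[Row], float],  # hypothetical: scores a single row
    rows: Sequence[Row],
    col: str,
    grid: Sequence[float],
) -> List[Dict[str, float]]:
    """One-way partial dependence of `predict` on column `col`.

    For each grid value v, every row gets col := v, all modified rows are
    scored, and the predictions are summarised by mean, standard deviation
    and standard error of the mean.
    """
    n = len(rows)
    table = []
    for v in grid:
        # Override the column of interest in every row, keep the rest as-is.
        preds = [predict({**row, col: v}) for row in rows]
        mean = statistics.mean(preds)
        std = statistics.stdev(preds) if n > 1 else 0.0  # sample std dev
        table.append({
            col: v,
            "mean_response": mean,
            "stddev_response": std,
            "std_error_mean_response": std / math.sqrt(n),
        })
    return table
```

Plotting mean_response against the grid values (optionally with a band of plus or minus stddev_response) gives the partial dependence plot.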

I am running H2O cluster version 3.20.0.7 through Sparkling Water on Qubole. Here is a minimal example with the iris data set:

%pyspark
# start h2o
from pysparkling import *
import h2o
hc = H2OContext.getOrCreate(spark)

# clean up the cluster just in case
h2o.remove_all()

# import data
iris = h2o.import_file("http://h2o-public-test-data.s3.amazonaws.com/smalldata/iris/iris_wheader.csv")

# convert response column to a factor
iris['class'] = iris['class'].asfactor()

# set the predictor names
predictors = iris.columns[:-1]

# split into train and validation sets
train, valid = iris.split_frame(ratios = [.8], seed = 1234)

# random forest
from h2o.estimators.random_forest import H2ORandomForestEstimator

rf_model = H2ORandomForestEstimator(
    score_each_iteration=True,
    score_tree_interval=5,
    max_runtime_secs=1800,
    stopping_metric='logloss',
    stopping_tolerance=0.001,
    stopping_rounds=3,
    sample_rate=0.7,
    col_sample_rate_per_tree=0.7,
    ntrees=1000,
    balance_classes=False,
    seed=456,
    nfolds=5
)

rf_model.train(x=predictors, y='class', training_frame=train)

# plot the scoring history
rf_model.plot()

Solution

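  • The problem is most likely not Spark itself but the notebook front end: the figure is created, but the Qubole notebook (which is based on Apache Zeppelin) does not display matplotlib figures automatically. A workaround is to draw with the non-interactive agg backend, save the current figure as SVG into an in-memory buffer, and print it with Zeppelin's %html prefix so that the output is rendered as HTML: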

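    import io

    import matplotlib
    matplotlib.use('agg')  # select the non-interactive backend before pyplot is imported
    import matplotlib.pyplot as plt

    def show(p):
        """Print the current matplotlib figure as inline SVG for Zeppelin."""
        img = io.StringIO()
        p.savefig(img, format='svg')
        img.seek(0)
        # Zeppelin renders paragraph output that starts with %html as HTML
        print("%html <div>" + img.getvalue() + "</div>")

    # variable importance as a pandas DataFrame (has a "variable" column)
    plot_varimp_df = rf_model.varimp(use_pandas=True)

    plt.clf()
    # htest is the H2OFrame to evaluate on (e.g. valid in the iris example above)
    rf_model.partial_plot(data=htest, cols=plot_varimp_df["variable"].tolist(), nbins=2, plot=True)
    show(plt)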
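If the inline SVG gets large (a partial dependence plot over many columns produces a tall figure), embedding a PNG as a base64 data URI may be lighter. This is a sketch under the same assumption that Zeppelin renders output starting with %html as HTML; the function name figure_to_zeppelin_png is hypothetical, and fig is assumed to be a matplotlib Figure or any object with a compatible savefig method.

```python
import base64
import io


def figure_to_zeppelin_png(fig, dpi=100):
    """Return a Zeppelin %html string that shows `fig` as an inline PNG.

    Hypothetical helper: `fig` is assumed to be a matplotlib Figure (or any
    object with a savefig(buffer, format=..., dpi=...) method).
    """
    buf = io.BytesIO()  # PNG output is binary, so use a bytes buffer
    fig.savefig(buf, format="png", dpi=dpi)
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return '%html <img src="data:image/png;base64,' + encoded + '"/>'
```

In the notebook, call it after rf_model.partial_plot(..., plot=True) as print(figure_to_zeppelin_png(plt.gcf())). Returning the string instead of printing it, and taking an explicit figure instead of the pyplot module, keeps the helper independent of pyplot's global state and easier to test.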