apache-spark, pyspark, apache-spark-ml

How to get all parameters of estimator in PySpark


I have a RandomForestRegressor and a GBTRegressor, and I'd like to get all of their parameters. The only way I have found is to call the individual getter methods, for example:

from pyspark.ml.regression import RandomForestRegressor, GBTRegressor
est = RandomForestRegressor()
est.getMaxDepth()
est.getSeed()

But RandomForestRegressor and GBTRegressor have different parameters, so hard-coding all of those methods is not a good idea. A workaround could be something like this:

get_methods = [method for method in dir(est) if method.startswith('get')]

params_est = {}
for method in get_methods:
    try:
        key = method[3:]
        params_est[key] = getattr(est, method)() 
    except TypeError:
        pass

The output then looks like this:

params_est

{'CacheNodeIds': False,
 'CheckpointInterval': 10,
 'FeatureSubsetStrategy': 'auto',
 'FeaturesCol': 'features',
 'Impurity': 'variance',
 'LabelCol': 'label',
 'MaxBins': 32,
 'MaxDepth': 5,
 'MaxMemoryInMB': 256,
 'MinInfoGain': 0.0,
 'MinInstancesPerNode': 1,
 'NumTrees': 20,
 'PredictionCol': 'prediction',
 'Seed': None,
 'SubsamplingRate': 1.0}

But I think there should be a better way to do that.


Solution

  • extractParamMap() returns every param that has a default or an explicitly set value, and it works the same way on any estimator, for example:

    >>> est = RandomForestRegressor()
    >>> {param[0].name: param[1] for param in est.extractParamMap().items()}
    {'numTrees': 20, 'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'minInstancesPerNode': 1, 'seed': -5851613654371098793, 'maxDepth': 5, 'featureSubsetStrategy': 'auto', 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'maxBins': 32}
    >>> est = GBTRegressor()
    >>> {param[0].name: param[1] for param in est.extractParamMap().items()}
    {'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'stepSize': 0.1, 'minInstancesPerNode': 1, 'seed': -6363326153609583521, 'maxDepth': 5, 'maxIter': 20, 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'lossType': 'squared', 'maxBins': 32}