I'd like to get the weights for the tree nodes from a saved (or unsaved) DecisionTreeClassificationModel. However, I can't find anything remotely resembling that.
How does the model actually perform the classification not knowing any of those? Below are the Params that are saved in the model:
{
  "class": "org.apache.spark.ml.classification.DecisionTreeClassificationModel",
  "timestamp": 1551207582648,
  "sparkVersion": "2.3.2",
  "uid": "DecisionTreeClassifier_4ffc94d20f1ddb29f282",
  "paramMap": {
    "cacheNodeIds": false,
    "maxBins": 32,
    "minInstancesPerNode": 1,
    "predictionCol": "prediction",
    "minInfoGain": 0.0,
    "rawPredictionCol": "rawPrediction",
    "featuresCol": "features",
    "probabilityCol": "probability",
    "checkpointInterval": 10,
    "seed": 956191873026065186,
    "impurity": "gini",
    "maxMemoryInMB": 256,
    "maxDepth": 2,
    "labelCol": "indexed"
  },
  "numFeatures": 1,
  "numClasses": 2
}
By using treeWeights:
treeWeights
Return the weights for each tree.
New in version 1.5.0.
As for
"How does the model actually perform the classification not knowing any of those?"
The weights are stored, just not as part of the metadata. If you have a model
from pyspark.ml.classification import RandomForestClassificationModel
model: RandomForestClassificationModel = ...
and save it to disk
path: str = ...
model.save(path)
you'll see that the writer creates a treesMetadata subdirectory. If you load the content (the default writer uses Parquet):
import os
trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))
you'll see the following structure:
trees_metadata.printSchema()
root
|-- treeID: integer (nullable = true)
|-- metadata: string (nullable = true)
|-- weights: double (nullable = true)
where the weights column contains the weight of the tree identified by treeID.
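To make the role of these per-tree weights concrete, here is a minimal sketch of a weighted vote over trees. It is an illustration only, not Spark's internal code: the function name combine_tree_votes and the sample numbers are hypothetical, and it assumes each tree contributes a per-class probability vector that is scaled by that tree's weight before summing.

```python
from typing import List, Sequence


def combine_tree_votes(
    weights: Sequence[float],
    per_tree_probs: Sequence[Sequence[float]],
) -> List[float]:
    """Sum per-class votes across trees, scaling each tree by its weight.

    Hypothetical helper: `weights` plays the role of the `weights` column
    in treesMetadata, `per_tree_probs[i]` is tree i's class distribution.
    """
    num_classes = len(per_tree_probs[0])
    totals = [0.0] * num_classes
    for weight, probs in zip(weights, per_tree_probs):
        for class_index, prob in enumerate(probs):
            totals[class_index] += weight * prob
    return totals


# Made-up example: three trees, two classes, all weights equal to 1.0.
example_weights = [1.0, 1.0, 1.0]
example_probs = [[0.75, 0.25], [0.5, 0.5], [0.5, 0.5]]

raw_votes = combine_tree_votes(example_weights, example_probs)
# The predicted class is the index with the largest combined vote.
predicted_class = max(range(len(raw_votes)), key=raw_votes.__getitem__)
```

With equal weights this reduces to a plain vote, so the weights only change the outcome when trees are weighted differently.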
Similarly, node data is stored in the data subdirectory (see, for example, Extract and Visualize Model Trees from Sparklyr):
spark.read.parquet(os.path.join(path, "data")).printSchema()
root
|-- id: integer (nullable = true)
|-- prediction: double (nullable = true)
|-- impurity: double (nullable = true)
|-- impurityStats: array (nullable = true)
|    |-- element: double (containsNull = true)
|-- gain: double (nullable = true)
|-- leftChild: integer (nullable = true)
|-- rightChild: integer (nullable = true)
|-- split: struct (nullable = true)
|    |-- featureIndex: integer (nullable = true)
|    |-- leftCategoriesOrThreshold: array (nullable = true)
|    |    |-- element: double (containsNull = true)
|    |-- numCategories: integer (nullable = true)
Equivalent information (minus tree data and tree weights) is available for DecisionTreeClassificationModel as well.
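The node rows shown above are enough to classify a feature vector by walking the tree, which is how prediction can work without any per-node "weights". Below is a minimal sketch in plain Python over rows collected as dictionaries. The function name predict_from_nodes and the sample rows are hypothetical, and the encoding conventions are assumptions to verify against your own saved data: the root has id 0, leaves have leftChild and rightChild equal to -1, a continuous split has numCategories equal to -1 with a single threshold, and a categorical split lists the categories that go left.

```python
from typing import Dict, List, Sequence


def predict_from_nodes(nodes: List[dict], features: Sequence[float]) -> float:
    """Walk the tree described by `nodes` and return the leaf prediction."""
    by_id: Dict[int, dict] = {node["id"]: node for node in nodes}
    node = by_id[0]  # assumption: the root node has id 0
    # Assumption: a leaf is marked by leftChild == -1 (and rightChild == -1).
    while node["leftChild"] != -1:
        split = node["split"]
        value = features[split["featureIndex"]]
        left = split["leftCategoriesOrThreshold"]
        if split["numCategories"] == -1:
            # Continuous feature: go left when value <= threshold.
            go_left = value <= left[0]
        else:
            # Categorical feature: go left when the category is listed.
            go_left = value in left
        node = by_id[node["leftChild"] if go_left else node["rightChild"]]
    return node["prediction"]


# Made-up rows for a depth-1 tree on a single continuous feature.
leaf_split = {"featureIndex": -1, "leftCategoriesOrThreshold": [], "numCategories": -1}
example_nodes = [
    {"id": 0, "prediction": 0.0, "leftChild": 1, "rightChild": 2,
     "split": {"featureIndex": 0, "leftCategoriesOrThreshold": [0.5], "numCategories": -1}},
    {"id": 1, "prediction": 0.0, "leftChild": -1, "rightChild": -1, "split": leaf_split},
    {"id": 2, "prediction": 1.0, "leftChild": -1, "rightChild": -1, "split": leaf_split},
]

low_prediction = predict_from_nodes(example_nodes, [0.2])
high_prediction = predict_from_nodes(example_nodes, [0.9])
```

With real data, the rows could be obtained by collecting the DataFrame read from the data subdirectory and converting each row with asDict(recursive=True).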