Tags: apache-spark, pyspark, apache-spark-ml

Spark ML: How does DecisionTreeClassificationModel know about the tree weights?


I'd like to get the weights for the tree nodes from a saved (or unsaved) DecisionTreeClassificationModel. However, I can't find anything remotely resembling that.

How does the model actually perform the classification without knowing any of those? Below are the Params that are saved in the model:

{
  "class": "org.apache.spark.ml.classification.DecisionTreeClassificationModel",
  "timestamp": 1551207582648,
  "sparkVersion": "2.3.2",
  "uid": "DecisionTreeClassifier_4ffc94d20f1ddb29f282",
  "paramMap": {
    "cacheNodeIds": false,
    "maxBins": 32,
    "minInstancesPerNode": 1,
    "predictionCol": "prediction",
    "minInfoGain": 0.0,
    "rawPredictionCol": "rawPrediction",
    "featuresCol": "features",
    "probabilityCol": "probability",
    "checkpointInterval": 10,
    "seed": 956191873026065186,
    "impurity": "gini",
    "maxMemoryInMB": 256,
    "maxDepth": 2,
    "labelCol": "indexed"
  },
  "numFeatures": 1,
  "numClasses": 2
}

Solution

  • By using treeWeights:

    treeWeights

    Return the weights for each tree

    New in version 1.5.0.
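    For intuition, an ensemble combines per-tree predictions as a weighted vote. Here is a minimal pure-Python sketch of that idea (the `weighted_vote` function and the sample values are hypothetical; Spark's random forests use a uniform weight of 1.0 per tree):

```python
from collections import defaultdict

def weighted_vote(predictions, weights):
    """Combine per-tree class predictions into a single label
    by summing each tree's weight into its predicted class."""
    scores = defaultdict(float)
    for label, weight in zip(predictions, weights):
        scores[label] += weight
    # Pick the class with the highest accumulated weight.
    return max(scores, key=scores.get)

# Three trees predicting classes 0.0/1.0 with uniform weights,
# as a random forest would use.
print(weighted_vote([0.0, 1.0, 1.0], [1.0, 1.0, 1.0]))  # 1.0
```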

    So

    How does the model actually perform the classification without knowing any of those?

    The weights are stored, just not as part of the metadata. If you have a model

    from pyspark.ml.classification import RandomForestClassificationModel
    
    model: RandomForestClassificationModel = ...
    

    and save it to disk

    path: str = ...
    
    model.save(path)
    

    you'll see that the writer creates a treesMetadata subdirectory. If you load its content (the default writer uses Parquet):

    import os
    
    trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))
    

    you'll see the following structure:

    trees_metadata.printSchema()
    
    root
     |-- treeID: integer (nullable = true)
     |-- metadata: string (nullable = true)
     |-- weights: double (nullable = true)
    

    where the weights column contains the weight of the tree identified by treeID.
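    Once collected to the driver, those rows can be turned into a plain treeID → weight lookup. A sketch, assuming rows shaped like the schema above (the sample values are made up, reduced to plain dicts):

```python
# Hypothetical sample of what trees_metadata.collect() might yield,
# each row reduced to a plain dict matching the schema above.
rows = [
    {"treeID": 0, "weights": 1.0},
    {"treeID": 1, "weights": 1.0},
    {"treeID": 2, "weights": 1.0},
]

# Build a treeID -> weight lookup.
tree_weights = {row["treeID"]: row["weights"] for row in rows}

# Flatten into a list ordered by tree index (presumably the
# same order treeWeights would return).
weights_list = [tree_weights[i] for i in sorted(tree_weights)]
print(weights_list)  # [1.0, 1.0, 1.0]
```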

    Similarly, node data is stored in the data subdirectory (see, for example, Extract and Visualize Model Trees from Sparklyr):

    spark.read.parquet(os.path.join(path, "data")).printSchema()     
    
    root
     |-- id: integer (nullable = true)
     |-- prediction: double (nullable = true)
     |-- impurity: double (nullable = true)
     |-- impurityStats: array (nullable = true)
     |    |-- element: double (containsNull = true)
     |-- gain: double (nullable = true)
     |-- leftChild: integer (nullable = true)
     |-- rightChild: integer (nullable = true)
     |-- split: struct (nullable = true)
     |    |-- featureIndex: integer (nullable = true)
     |    |-- leftCategoriesOrThreshold: array (nullable = true)
     |    |    |-- element: double (containsNull = true)
     |    |-- numCategories: integer (nullable = true)
    

    Equivalent information is available for DecisionTreeClassificationModel as well, minus the per-tree metadata and tree weights (treesMetadata), which exist only for ensemble models.