Tags: python, numpy, scikit-learn, classification, multiclass-classification

Using sklearn for hierarchical classification


I was wondering whether hierarchical classification is supported by the scikit-learn library. I am dealing with 3 classes, each divided into 6 subclasses, such as:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.random.randn(5, 1)

number, rows, cols = 5, 3, 6
y = np.zeros((number, rows, cols), dtype=int)
for n in range(number):
    for row in range(rows):
        col = np.random.randint(cols)
        y[n, row, col] = 1

tree = DecisionTreeClassifier()
tree.fit(X, y)

but I get the following error:

ValueError: Found array with dim 3. DecisionTreeClassifier expected <= 2.
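
For context, the error is raised because `DecisionTreeClassifier` accepts a target array with at most two dimensions, shaped `(n_samples,)` or `(n_samples, n_outputs)`. A minimal sketch of one way to make the snippet run with scikit-learn alone, assuming it is acceptable to treat each of the 3 classes as an independent output (so the hierarchy itself is not modelled), is to collapse the one-hot axis into subclass indices. The names `rng`, `y_idx` and `pred` are illustrative.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
number, rows, cols = 5, 3, 6
X = rng.standard_normal((number, 1))

# One-hot targets with shape (number, rows, cols), as in the question
y = np.zeros((number, rows, cols), dtype=int)
for n in range(number):
    for row in range(rows):
        y[n, row, rng.integers(cols)] = 1

# Collapse the one-hot axis: y_idx[n, row] is the subclass index chosen
# for class `row` of sample `n`, giving a 2-D array of shape (number, rows)
y_idx = y.argmax(axis=2)

# A 2-D integer target is treated as a multi-output multiclass problem
tree = DecisionTreeClassifier(random_state=0)
tree.fit(X, y_idx)
pred = tree.predict(X)  # one predicted subclass index per class
print(pred.shape)
```

This only removes the shape error. It does not enforce any parent/child consistency between levels, which is what a dedicated hierarchical classifier is for.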

Solution

  • You can use hiclass.

    pip install hiclass
    

    Train the model:

    from sklearn.tree import DecisionTreeClassifier
    from hiclass.MultiLabelLocalClassifierPerNode import MultiLabelLocalClassifierPerNode
    
    tree = DecisionTreeClassifier()
    classifier = MultiLabelLocalClassifierPerNode(local_classifier=tree)
    classifier.fit(X, y)
    

    Test and measure precision:

    from hiclass.metrics import precision
    
    # X_test and y_test are a held-out split prepared the same way as X and y
    predictions = classifier.predict(X_test)
    p = precision(y_test, predictions)
    print(p)
    

    Refer to the hiclass paper for complementary information.
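
To make the precision score above more concrete, here is a minimal pure-Python sketch of micro-averaged hierarchical precision as it is commonly defined: the number of labels shared by the predicted and true paths, divided by the number of predicted labels. This is an illustration under that assumption, not hiclass's own implementation; the function `hierarchical_precision` and the example labels are hypothetical, and node labels are assumed to be unique across levels.

```python
def hierarchical_precision(y_true, y_pred):
    """Micro-averaged hierarchical precision over root-to-leaf label paths."""
    overlap = 0    # labels that appear in both the true and predicted path
    predicted = 0  # total number of predicted labels
    for true_path, pred_path in zip(y_true, y_pred):
        true_set, pred_set = set(true_path), set(pred_path)
        overlap += len(true_set & pred_set)
        predicted += len(pred_set)
    return overlap / predicted if predicted else 0.0


# Hypothetical two-level hierarchy: each sample is a [parent, child] path
y_true = [["animal", "cat"], ["animal", "dog"], ["plant", "oak"]]
y_pred = [["animal", "cat"], ["animal", "cat"], ["plant", "oak"]]

score = hierarchical_precision(y_true, y_pred)
print(score)  # 5 shared labels out of 6 predicted
```

The second sample still earns partial credit because the parent label `animal` is correct even though the leaf is wrong, which is the main difference from flat precision.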