Tags: python, machine-learning, scikit-learn, decision-tree

How to set class weights in DecisionTreeClassifier for a multi-class setting


I am using sklearn.tree.DecisionTreeClassifier to train a model for a 3-class classification problem.

The number of records in each of the 3 classes is given below:

A: 122038
B: 43626
C: 6678

When I train the classifier, it fails to learn class C. Accuracy comes out at 65-70%, but the model completely ignores class C.

Then I learned about the class_weight parameter, but I am not sure how to use it in a multi-class setting.

Here is my code (I used 'balanced', but it gave even poorer accuracy):

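from sklearn import tree
from sklearn.model_selection import train_test_split

# X, y: feature matrix and class labels (A, B, C), loaded elsewhere
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=1, class_weight='balanced')
clf = clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)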

How can I set weights in proportion to the class distribution?

Secondly, is there a better way to address this class imbalance problem and increase accuracy?


Solution

  • You can also pass a dictionary to the class_weight argument to set your own weights. The keys must match the class labels in y. For example, to weight class A half as much as the other classes, you could do:

    class_weight={
        'A': 0.5,
        'B': 1.0,
        'C': 1.0
    }
    

    Setting class_weight='balanced' automatically sets the weights inversely proportional to the class frequencies in the training data, computed as n_samples / (n_classes * np.bincount(y)).

    More information can be found in the docs under the class_weight argument: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
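To make the weighting concrete, here is a minimal sketch that computes 'balanced'-style weights by hand from the class counts given in the question. The helper name `balanced_weights` is hypothetical, and the counts are those of the full data set, so the values a fitted model derives from the 70% training split will differ slightly.

```python
# Class counts taken from the question (full data set, before the train/test split).
counts = {"A": 122038, "B": 43626, "C": 6678}


def balanced_weights(class_counts):
    """Return weights inversely proportional to class frequency.

    Uses n_samples / (n_classes * count), the heuristic that the
    scikit-learn docs describe for class_weight='balanced'.
    """
    n_samples = sum(class_counts.values())
    n_classes = len(class_counts)
    return {
        label: n_samples / (n_classes * count)
        for label, count in class_counts.items()
    }


weights = balanced_weights(counts)

# The rarest class (C) gets the largest weight, the most common (A) the smallest.
for label, weight in weights.items():
    print(f"{label}: {weight:.3f}")
```

A dictionary like `weights` can be passed directly as `class_weight=weights`, and individual entries can be adjusted by hand if the fully balanced weights push the model too far toward the minority class.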

    Balancing the classes usually reduces overall accuracy, because the model gives up some correct majority-class predictions in order to get more minority-class predictions right. This is why accuracy is often considered a poor metric for imbalanced data sets.

    As a starting point, you can try the balanced accuracy metric that sklearn includes (sklearn.metrics.balanced_accuracy_score). There are many other metrics to consider, and the right one depends on what your ultimate goal is.

    https://scikit-learn.org/stable/modules/model_evaluation.html
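A small worked example may show why plain accuracy misleads here. The sketch below uses the class counts from the question and a hypothetical model that always predicts A. The two helper functions are hand-rolled for illustration; balanced accuracy is taken as the mean of per-class recall, which is how sklearn.metrics.balanced_accuracy_score is documented.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall, so every class counts equally."""
    support = Counter(y_true)  # number of true samples per class
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return sum(hits[c] / support[c] for c in support) / len(support)


# Labels with the class counts from the question.
y_true = ["A"] * 122038 + ["B"] * 43626 + ["C"] * 6678
# Hypothetical degenerate model: always predict the majority class.
y_pred = ["A"] * len(y_true)

acc = accuracy(y_true, y_pred)
bal = balanced_accuracy(y_true, y_pred)
print(f"accuracy:          {acc:.3f}")  # 0.708
print(f"balanced accuracy: {bal:.3f}")  # 0.333
```

A model that never predicts B or C still reaches about 71% accuracy on this distribution, which is in the same range as the 65-70% reported in the question, while its balanced accuracy is no better than chance for three classes.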

    If you are not familiar with the confusion matrix and related measures such as precision and recall, I would start your research there.

    https://en.wikipedia.org/wiki/Precision_and_recall

    https://en.wikipedia.org/wiki/Confusion_matrix

    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
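As a rough end-to-end sketch of the evaluation suggested above, the following trains a class-weighted tree and prints the confusion matrix and per-class precision and recall. The data is synthetic (generated with make_classification to roughly mimic the 71/25/4 split from the question), so it is an assumption standing in for the asker's X and y, and the numbers it prints say nothing about the real data set.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the question's data (an assumption):
# three classes with roughly 71% / 25% / 4% of the samples.
X, y = make_classification(
    n_samples=5000,
    n_features=10,
    n_informative=4,
    n_classes=3,
    weights=[0.71, 0.25, 0.04],
    random_state=1,
)

# stratify=y keeps the class proportions the same in both splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=1
)

clf = DecisionTreeClassifier(max_depth=3, class_weight="balanced", random_state=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1, 2])
print(cm)

# Per-class precision, recall and F1; zero_division=0 avoids warnings
# if some class is never predicted.
print(classification_report(y_test, y_pred, zero_division=0))
```

Reading the matrix row by row shows how many samples of each true class were recovered, which makes it easy to see whether the minority class is still being ignored.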