I am using sklearn.tree.DecisionTreeClassifier to train a 3-class classification problem.
The number of records in each of the 3 classes is given below:
A: 122038
B: 43626
C: 6678
When I train the classifier, it fails to learn class C. Although the accuracy comes out at 65-70%, the model completely ignores class C.
Then I learned about the class_weight parameter, but I am not sure how to use it in a multiclass setting.
Here is my code (I used class_weight='balanced', but it gave even poorer accuracy):
from sklearn import tree
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=1, class_weight='balanced')
clf = clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
How can I set weights in proportion to the class distribution?
Secondly, is there a better way to address this class imbalance problem and increase accuracy?
You can also pass a dictionary of values to the class_weight argument in order to set your own weights. The keys must match the class labels in y. For example, to give class A half as much weight as the other classes, you could do:
class_weight={
'A': 0.5,
'B': 1.0,
'C': 1.0
}
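A minimal sketch of passing such a dictionary to DecisionTreeClassifier. The data here is a synthetic stand-in generated with make_classification (an assumption, since the real X and y are not shown), with roughly the same 71% / 25% / 4% split as A / B / C in the question.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import make_classification

# Synthetic stand-in for the question's data (assumption): 3 imbalanced classes.
X, y_int = make_classification(
    n_samples=5000, n_features=10, n_informative=5, n_classes=3,
    weights=[0.71, 0.25, 0.04], random_state=1,
)
y = np.array(["A", "B", "C"])[y_int]  # string labels, as in the question

# The dictionary keys must match the labels that appear in y exactly.
class_weight = {"A": 0.5, "B": 1.0, "C": 1.0}

clf = tree.DecisionTreeClassifier(max_depth=3, random_state=1, class_weight=class_weight)
clf.fit(X, y)
print(clf.classes_)  # ['A' 'B' 'C']
```

If the labels in y are integers (0, 1, 2) rather than strings, the dictionary keys have to be those integers instead.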
Setting class_weight='balanced' automatically sets the weights inversely proportional to the class frequencies in the training data, computed as n_samples / (n_classes * np.bincount(y)).
More information can be found in the docs under the class_weight argument: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
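To see what 'balanced' produces for the class counts in the question, the sketch below rebuilds a label array with those counts and compares sklearn's compute_class_weight against the formula written out by hand.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Class counts from the question; rebuild a label array with those counts.
counts = {"A": 122038, "B": 43626, "C": 6678}
y = np.repeat(list(counts), list(counts.values()))

classes = np.array(sorted(counts))
# 'balanced' weight for class c: n_samples / (n_classes * n_samples_in_c)
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)

manual = {c: len(y) / (len(classes) * n) for c, n in counts.items()}
for c, w in zip(classes, weights):
    print(f"{c}: {w:.3f} (manual {manual[c]:.3f})")
```

With these counts the weights come out at roughly A ≈ 0.471, B ≈ 1.317 and C ≈ 8.602, so each C sample counts for about 18 times as much as an A sample. Passing a dictionary with values in between (for example a smaller weight for C) is one way to trade some overall accuracy for recall on C.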
It is to be expected that balancing the classes will usually reduce overall accuracy, because the model stops favouring the majority class. This is why accuracy is often considered a poor metric for imbalanced data sets.
To start, you can try the balanced accuracy metric that sklearn includes (balanced_accuracy_score), but there are many other metrics to try, and the right one will depend on your ultimate goal.
https://scikit-learn.org/stable/modules/model_evaluation.html
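As an illustration of why plain accuracy is misleading here, the sketch below scores a hypothetical "classifier" that always predicts A against the class counts from the question. It already reaches about 71% accuracy, the same range as the reported 65-70%, while its balanced accuracy (the mean recall over classes) is only 1/3.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Label counts from the question; a hypothetical model that always predicts A.
counts = {"A": 122038, "B": 43626, "C": 6678}
y_true = np.repeat(list(counts), list(counts.values()))
y_pred = np.full_like(y_true, "A")

print(f"accuracy:          {accuracy_score(y_true, y_pred):.3f}")           # 0.708
print(f"balanced accuracy: {balanced_accuracy_score(y_true, y_pred):.3f}")  # 0.333
```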
If you are not familiar with the confusion matrix and related values such as precision and recall, I would start your research there.
https://en.wikipedia.org/wiki/Precision_and_recall
https://en.wikipedia.org/wiki/Confusion_matrix
https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
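A minimal sketch of inspecting per-class results with confusion_matrix and classification_report, following the same train/test split and tree settings as the question's code. The data is again a synthetic, imbalanced stand-in (an assumption); with the real X and y, the row for C and its recall show directly whether class C is being learned.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced stand-in for the question's data (assumption).
X, y_int = make_classification(
    n_samples=5000, n_features=10, n_informative=5, n_classes=3,
    weights=[0.71, 0.25, 0.04], random_state=1,
)
y = np.array(["A", "B", "C"])[y_int]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=1, class_weight="balanced")
y_pred = clf.fit(X_train, y_train).predict(X_test)

labels = ["A", "B", "C"]
# Rows are true classes, columns are predicted classes, in the order of `labels`.
cm = confusion_matrix(y_test, y_pred, labels=labels)
print(cm)
# Per-class precision, recall and F1; zero_division=0 avoids warnings if a class is never predicted.
print(classification_report(y_test, y_pred, labels=labels, zero_division=0))
```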