
Plot multi-class ROC curve for DecisionTreeClassifier


I was trying to plot a ROC curve with classifiers other than svm.SVC, which is the one used in the documentation. My code works well for svm.SVC; however, after I switched to KNeighborsClassifier, MultinomialNB, and DecisionTreeClassifier, the call to check_consistent_length(y_true, y_score) keeps failing with "Found input variables with inconsistent numbers of samples: [26632, 53264]". My CSV file looks like this:

And here is my code

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
import sys
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeClassifier
# Import some data to play with
df = pd.read_csv("E:\\autodesk\\Hourly and weather categorized2.csv")
X = df[['TTI', 'Max TemperatureF', 'Mean TemperatureF', 'Min TemperatureF', ' Min Humidity']].values
y = df['TTI_Category'].values  # .values replaces the deprecated as_matrix()
y = y.reshape(-1, 1)
# Binarize the output
y = label_binarize(y, classes=['Good','Bad'])
n_classes = y.shape[1]

# shuffle and split training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
                                                    random_state=0)

# Learn to predict each class against the other
classifier = OneVsRestClassifier(DecisionTreeClassifier(random_state=0))
y_score = classifier.fit(X_train, y_train).predict_proba(X_test)

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()

roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
plt.figure()
lw = 1
plt.plot(fpr[0], tpr[0], color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[0])
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()

I suspect that the error occurs at the lines fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) and roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]), but I'm a beginner with ROC curves, so could someone kindly guide me through this traceback? Thanks a lot for your time and help. I have also posted another question regarding ROC curves. Here is the whole traceback; hopefully my explanation is clear enough.

Traceback (most recent call last):

  File "<ipython-input-1-16eb0db9d4d9>", line 1, in <module>
    runfile('C:/Users/Think/Desktop/Python Practice/ROC with decision tree.py', wdir='C:/Users/Think/Desktop/Python Practice')

  File "C:\Users\Think\Anaconda2\lib\site-packages\spyder\utils\site\sitecustomize.py", line 880, in runfile
    execfile(filename, namespace)

  File "C:\Users\Think\Anaconda2\lib\site-packages\spyder\utils\site\sitecustomize.py", line 87, in execfile
    exec(compile(scripttext, filename, 'exec'), glob, loc)

  File "C:/Users/Think/Desktop/Python Practice/ROC with decision tree.py", line 47, in <module>
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())

  File "C:\Users\Think\Anaconda2\lib\site-packages\sklearn\metrics\ranking.py", line 510, in roc_curve
    y_true, y_score, pos_label=pos_label, sample_weight=sample_weight)

  File "C:\Users\Think\Anaconda2\lib\site-packages\sklearn\metrics\ranking.py", line 302, in _binary_clf_curve
    check_consistent_length(y_true, y_score)

  File "C:\Users\Think\Anaconda2\lib\site-packages\sklearn\utils\validation.py", line 173, in check_consistent_length
    " samples: %r" % [int(l) for l in lengths])

ValueError: Found input variables with inconsistent numbers of samples: [26632, 53264]

Solution

  • The problem was solved by adding this line to the original code: y_resampled = label_binarize(y_resampled, classes=['Good','Bad','Ok'])
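
A likely reason the fix works: with only two classes, label_binarize returns a single column, while predict_proba returns one column per class, so ravel() yields n values on one side and 2 * n on the other (26632 versus 53264 in the traceback). The sketch below reproduces the shape mismatch on synthetic data. The random X, the toy labels, and a plain DecisionTreeClassifier without OneVsRestClassifier are assumptions made for illustration, not the asker's actual data or pipeline.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the CSV (hypothetical data, not the asker's file).
rng = np.random.RandomState(0)
X = rng.rand(300, 3)
labels = np.where(X[:, 0] > 0.5, "Bad", "Good")  # toy two-class labels

# Two classes: label_binarize returns ONE column, not two.
y_two = label_binarize(labels, classes=["Good", "Bad"])
# Three classes: one column per class.
y_three = label_binarize(labels, classes=["Good", "Bad", "Ok"])

clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, labels)
y_score = clf.predict_proba(X)  # one column per entry in clf.classes_

print(y_two.shape, y_three.shape, y_score.shape)
# After ravel() the lengths differ: n versus 2 * n.
print(y_two.ravel().shape, y_score.ravel().shape)

# For a truly binary problem, score a single probability column instead.
# label_binarize marks the second listed class ("Bad") as the positive one.
bad_col = list(clf.classes_).index("Bad")
fpr, tpr, _ = roc_curve(y_two.ravel(), y_score[:, bad_col])
roc_auc = auc(fpr, tpr)
```

If the data really has only two categories, passing one probability column to roc_curve, as in the last lines, avoids the mismatch without adding a third class. If a third category such as 'Ok' does occur in the data, listing all three in label_binarize gives one column per class, matching predict_proba.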