python scikit-learn xor

Warning message in scikit-learn


I wrote a very simple scikit-learn decision tree to implement XOR:

from sklearn import tree
X = [[0, 0], [1, 1], [0, 1], [1, 0]]
Y = [0, 0, 1, 1]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)

print(clf.predict([0,1]))
print(clf.predict([0,0]))
print(clf.predict([1,1]))
print(clf.predict([1,0]))

The predict calls generate the following warning:

DeprecationWarning: Passing 1d arrays as data is deprecated in 0.17 and will raise ValueError in 0.19. Reshape your data either using X.reshape(-1, 1) if your data has a single feature or X.reshape(1, -1) if it contains a single sample.

What needs to change and why?


Solution

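  • The input to clf.predict must be a 2D array-like of shape (n_samples, n_features), just like the X passed to fit. A plain list such as [0,1] is 1D, so scikit-learn cannot tell whether it is one sample with two features or two samples with one feature. Thus, instead of writing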

    print(clf.predict([0,1]))

    you need to write

    print(clf.predict([[0,1]]))
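For reference, here is a minimal corrected version of the script from the question. It assumes scikit-learn and NumPy are installed. It shows three ways to give predict the 2D input it expects: wrap a single sample in an outer list, predict all four samples in one call, or reshape a 1D NumPy array with reshape(1, -1) as the warning suggests. The names predictions and sample are illustrative only.

```python
import numpy as np
from sklearn import tree

# XOR training data: each inner list is one sample with two features.
X = [[0, 0], [1, 1], [0, 1], [1, 0]]
Y = [0, 0, 1, 1]

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)

# Option 1: wrap a single sample in an outer list, giving shape (1, 2).
print(clf.predict([[0, 1]]))

# Option 2: predict all four samples in one call, giving shape (4, 2).
predictions = clf.predict([[0, 1], [0, 0], [1, 1], [1, 0]])
print(predictions)

# Option 3: turn a 1D array of one sample into a single row, as the warning suggests.
sample = np.array([1, 0])
print(sample.reshape(1, -1).shape)  # (1, 2): one sample, two features
print(clf.predict(sample.reshape(1, -1)))
```

Predicting in a batch (option 2) is usually preferable when you have several samples, since it is a single call and returns one prediction per row.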