python machine-learning knn multiclass-classification

Why do I get 0 neighbors in my KNN model?


I'm doing multi-class audio classification, where I have:

1) the path to each .wav file, each of which records an individual word
2) a vector of MFCCs for each path
3) a label (the actual word) for each path in the training set

Apparently, my algorithm is not recognizing any neighbor. Any reason?

My sets look like:

X_train, X_test, y_train, y_test = train_test_split(new_X, y, test_size=0.4, random_state=5)
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

(56894, 99)
(56894,)
(37930, 99)
(37930,)

My model is:

k_range = list(range(1))
scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    scores.append(metrics.accuracy_score(y_test, y_pred))

Then I get this error:

ValueError                                Traceback (most recent call last)
<ipython-input-12-d67236a4ac61> in <module>()
      3 for k in k_range:
      4     knn = KNeighborsClassifier(n_neighbors=k)
----> 5     knn.fit(X_train, y_train)
      6     y_pred = knn.predict(X_test)
      7     scores.append(metrics.accuracy_score(y_test, y_pred))

1 frames
/usr/local/lib/python3.6/dist-packages/sklearn/neighbors/base.py in fit(self, X, y)
    915             self._y = self._y.ravel()
    916 
--> 917         return self._fit(X)
    918 
    919 

/usr/local/lib/python3.6/dist-packages/sklearn/neighbors/base.py in _fit(self, X)
    266                 raise ValueError(
    267                     "Expected n_neighbors > 0. Got %d" %
--> 268                     self.n_neighbors
    269                 )
    270             else:

ValueError: Expected n_neighbors > 0. Got 0

Solution

  • It's not failing to recognise neighbours; it's complaining that you passed n_neighbors=0.

    Your loop iterates over k_range, which is list(range(1)), i.e. [0] (range(1) starts at 0 and stops before 1). So on its only iteration, the failing call evaluates to:

    knn = KNeighborsClassifier(n_neighbors=0)
    

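    You need to change the range in:

    k_range = list(range(1))

    so that it starts at 1 rather than 0, for example list(range(1, 11)) to try k = 1 through 10 (the upper bound is only an illustration; pick whatever range you want to search).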
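As a minimal sketch of the corrected loop: the data below is synthetic, generated with make_classification as a stand-in for the MFCC features and word labels (an assumption, not the asker's data), and the upper bound of 10 for k is arbitrary. The only change that matters for the error is that k_range now starts at 1.

```python
from sklearn import metrics
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in for the real features and labels (assumption):
# 300 samples, 20 features, 3 classes.
X, y = make_classification(n_samples=300, n_features=20, n_informative=5,
                           n_classes=3, random_state=5)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=5)

# range(1) yields only 0; range(1, 11) yields 1..10, so k is never 0.
k_range = list(range(1, 11))
scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    scores.append(metrics.accuracy_score(y_test, y_pred))

# Report the k with the highest test accuracy (first one on ties).
best_k = k_range[scores.index(max(scores))]
print(best_k, max(scores))
```

With the original shapes (56894 training rows), each fit/predict pass will be considerably slower than on this toy data, so it may be worth starting with a short range of k values.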