python, scikit-learn, multilabel-classification

The number of classes has to be greater than one; got 1 class on the MultiOutputClassifier from sklearn


I get an error saying there are fewer than two classes. This is the relevant part of the code:

texts = [element['text'] for element in train_data]
labels = [element['labels'] if element['labels'] else ['Free'] for element in train_data]

mlb = MultiLabelBinarizer(classes=G.nodes)
y_bin = mlb.fit_transform(labels)

# texts was turned into X_reduced because I made some other changes there
X_train, X_test, y_train, y_test = train_test_split(X_reduced, y_bin, test_size=0.2, random_state=42)

multi_label_classifier = MultiOutputClassifier(SVC(kernel='linear', probability=True))
y_train = y_train.astype(np.uint8)  # I saw this in an earlier post, but it didn't work
multi_label_classifier.fit(X_train, y_train)

y_train is a 400×31 matrix. As you can see, I even added a new class "Free" to make sure no row of the matrix contains only zeros.

To be even more sure, I ran these checks.

len(np.unique(y_train))

result --> 2

def atleast_one(matrix):
    for row in matrix:
        if 1 in row:
            continue
        else:
            print(row)
            return False
    # every row contains a 1
    return True
atleast_one(y_train)

result --> True

np.any(np.all(y_train == 0, axis=1))

result --> False

After all of this, I still can't run the fit because of this error, and I don't understand why.

This is the error:

ValueError                                Traceback (most recent call last)
<ipython-input-14-cb7366c3ac07> in <cell line: 55>()
     53 multi_label_classifier = MultiOutputClassifier(SVC(kernel='linear', probability=True))
     54 y_train=y_train.astype(np.uint8)
---> 55 multi_label_classifier.fit(X_train, y_train)
     56 
     57 def concatenate_row_elements(matrix):

/usr/local/lib/python3.10/dist-packages/sklearn/svm/_base.py in _validate_targets(self, y)
    747         self.class_weight_ = compute_class_weight(self.class_weight, classes=cls, y=y_)
    748         if len(cls) < 2:
--> 749             raise ValueError(
    750                 "The number of classes has to be greater than one; got %d class"
    751                 % len(cls)
ValueError: The number of classes has to be greater than one; got 1 class

P.S. This is my first question, sorry if I am making any mistakes.


Solution

  • At least one column of y_train (after multilabel binarizing and splitting into train/test) appears to contain only zeros or only ones. What you checked so far concerns the rows and the matrix as a whole, but MultiOutputClassifier fits one SVC per target column, and the SVM is complaining about one of those columns.
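
A minimal sketch of a column-wise check, assuming y_train is a 2-D NumPy array of 0/1 labels. The helper name constant_columns and the toy matrix y_demo are illustrative and not part of the original code:

```python
import numpy as np

def constant_columns(y):
    """Return the indices of label columns that hold a single distinct value."""
    y = np.asarray(y)
    # A column is constant exactly when its minimum equals its maximum.
    return np.flatnonzero(y.min(axis=0) == y.max(axis=0))

# Toy label matrix (hypothetical data): column 1 is all zeros, column 3 is all ones.
y_demo = np.array([
    [1, 0, 0, 1],
    [0, 0, 1, 1],
    [1, 0, 1, 1],
])

bad = constant_columns(y_demo)
print("constant columns:", bad)

# One possible workaround: keep only the columns that have both classes.
keep = np.setdiff1d(np.arange(y_demo.shape[1]), bad)
y_filtered = y_demo[:, keep]
print("filtered shape:", y_filtered.shape)
```

Running constant_columns(y_train) on the real training matrix should show which labels have only one class in the training split. A likely source, though this is an assumption about the data, is MultiLabelBinarizer(classes=G.nodes): every node that never occurs in the labels still gets a column, and that column is all zeros. A rare label can also end up entirely in the test split.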