python machine-learning scikit-learn text-classification

SciKit-learn's 'predict' function giving output in wrong format


I am new to scikit-learn and am experimenting with it.

Background: I am trying the 'Byte the correct apple' competition on HackerRank, in which we are given two files, one containing text about Apple the company and one about apple the fruit. We must learn from these files and then make a prediction on new text.

The code runs, but I have two problems:
- Since 'line' (in the code below) is a single input, I expect a single-digit output, either zero or one, but I get an array instead.
- Am I even close to learning anything with the code below?

import numpy as np

from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer


from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import SGDClassifier
from sklearn import svm
from sklearn.svm import LinearSVC

from sklearn.pipeline import Pipeline

appleComputers = []
appleFruits = []
labels = []

# Label 1: lines about Apple the company
with open('apple-computers.txt', 'r') as f:
    for line in f:
        appleComputers.append(line)
        labels.append(1)

# Label 0: lines about apple the fruit
with open('apple-fruit.txt', 'r') as f:
    for line in f:
        appleFruits.append(line)
        labels.append(0)

text = appleComputers + appleFruits
labels = np.asarray(labels)

#text_clf = Pipeline([('vect', CountVectorizer()),('tfidf', TfidfTransformer()),('clf', MultinomialNB()),])
text_clf = Pipeline([('vect', CountVectorizer()),('tfidf', TfidfTransformer()),('clf', LinearSVC(loss='hinge', penalty='l2')),])

text_clf = text_clf.fit(text, labels)


# Two test sentences; only the second assignment takes effect.
line = 'I am talking about apple the fruit we eat.'
line = 'I am talking about the product apple computer by Steve Jobs'
predicted = text_clf.predict(line);
print(predicted)

Solution

  • I found the answer myself.

    In

    predicted = text_clf.predict(line);
    

    'line' should be a list of strings, not a single string, just as the input to 'fit' was a list.

    That is, replace

    line = 'I am talking about the product apple computer by Steve Jobs'
    

    with

    line = [];    
    line.append('I am talking about apple the fruit we eat.');
    

    Or, as @jme suggested, we can use

    text_clf.predict([line])
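
To illustrate the fix, here is a minimal sketch of the same pipeline with the input wrapped in a list. The in-memory sentences are hypothetical stand-ins for the two training files, and LinearSVC is used with its default settings for brevity, so this is not the exact setup from the question. Note that 'predict' still returns an array, with one entry per input document, so the single label has to be taken with [0].

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

# Hypothetical stand-ins for apple-computers.txt and apple-fruit.txt.
computer_lines = [
    "Apple released a new iPhone and MacBook this year.",
    "Steve Jobs founded Apple Computer in a garage.",
    "The Apple store sells laptops, phones and tablets.",
]
fruit_lines = [
    "I ate a juicy red apple for lunch.",
    "Apple pie is baked with fresh fruit and cinnamon.",
    "The apple tree in the orchard is full of ripe fruit.",
]

text = computer_lines + fruit_lines
labels = [1] * len(computer_lines) + [0] * len(fruit_lines)

# Same three stages as in the question; LinearSVC defaults are an assumption.
text_clf = Pipeline([
    ("vect", CountVectorizer()),
    ("tfidf", TfidfTransformer()),
    ("clf", LinearSVC()),
])
text_clf.fit(text, labels)

line = "I am talking about apple the fruit we eat."

# predict expects an iterable of documents, so wrap the string in a list.
predicted = text_clf.predict([line])

# The result is an array with one label per document; index it for a scalar.
label = int(predicted[0])
print(predicted.shape, label)
```

Passing several sentences at once works the same way: text_clf.predict([line_a, line_b]) gives an array of two labels, one per sentence.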