java weka text-classification categorization

Text classifier with Weka: how to correctly train a classifier


I'm trying to build a text classifier using Weka, but the class probabilities returned by distributionForInstance are 1.0 for one class and 0.0 for all the others, so classifyInstance always returns the same class as its prediction. Something in the training is not working correctly.

Training ARFF file

@relation test1

@attribute tweetmsg    String
@attribute classValues {politica,sport,musicatvcinema,infogeneriche,fattidelgiorno,statopersonale,checkin,conversazione}

@DATA

"Renzi Berlusconi Salvini Bersani",politica
"Allegri insulta la terna arbitrale",sport
"Bravo Garcia",sport

Training methods

public void trainClassifier(final String INPUT_FILENAME) throws Exception
{
    getTrainingDataset(INPUT_FILENAME);

    //trainingInstances holds the feature vector of every input

    for(Instance currentInstance : inputDataset)
    {           
        Instance currentFeatureVector = extractFeature(currentInstance);

        currentFeatureVector.setDataset(trainingInstances);
        trainingInstances.add(currentFeatureVector);                
    }

    classifier = new NaiveBayes();

    try {
        //classifier training code
        classifier.buildClassifier(trainingInstances);

        //storing the trained classifier to a file for future use
        weka.core.SerializationHelper.write("NaiveBayes.model",classifier);
    } catch (Exception ex) {
        System.out.println("Exception in training the classifier."+ex);
    }
}

private Instance extractFeature(Instance inputInstance) throws Exception
{       
    String tweet = inputInstance.stringValue(0);
    StringTokenizer defaultTokenizer = new StringTokenizer(tweet);
    List<String> tokens=new ArrayList<String>();
    while (defaultTokenizer.hasMoreTokens())
    {
        String t= defaultTokenizer.nextToken();
        tokens.add(t);
    }

    Iterator<String> a = tokens.iterator();
    while(a.hasNext())
    {
        String token=(String) a.next();
        String word = token.replaceAll("#","");
        if(featureWords.contains(word))
        {
            double cont=featureMap.get(featureWords.indexOf(word))+1;
            featureMap.put(featureWords.indexOf(word),cont);
        }
        else{
            featureWords.add(word);
            featureMap.put(featureWords.indexOf(word), 1.0);
        }

    }
    attributeList.clear();
    for(String featureWord : featureWords)
    {
        attributeList.add(new Attribute(featureWord));   
    }
    attributeList.add(new Attribute("Class", classValues));
    int indices[] = new int[featureMap.size()+1];
    double values[] = new double[featureMap.size()+1];
    int i=0;
    for(Map.Entry<Integer,Double> entry : featureMap.entrySet())
    {
        indices[i] = entry.getKey();
        values[i] = entry.getValue();
        i++;
    }
    indices[i] = featureWords.size();
    values[i] = (double)classValues.indexOf(inputInstance.stringValue(1));
    trainingInstances = createInstances("TRAINING_INSTANCES");

    return new SparseInstance(1.0,values,indices,1000000);
}


private void getTrainingDataset(final String INPUT_FILENAME)
{
    try{
        ArffLoader trainingLoader = new ArffLoader();
        trainingLoader.setSource(new File(INPUT_FILENAME));
        inputDataset = trainingLoader.getDataSet();
    }catch(IOException ex)
    {
        System.out.println("Exception in getTrainingDataset Method");
    }
    System.out.println("dataset "+inputDataset.numAttributes());
}

private Instances createInstances(final String INSTANCES_NAME)
{
    //create an Instances object with initial capacity as zero 
    Instances instances = new Instances(INSTANCES_NAME,attributeList,0);
    //sets the class index as the last attribute
    instances.setClassIndex(instances.numAttributes()-1);

    return instances;
}

public static void main(String[] args) throws Exception
{
      Classificatore wekaTutorial = new Classificatore();
      wekaTutorial.trainClassifier("training_set_prova_tent.arff");
      wekaTutorial.testClassifier("testing.arff");
}

public Classificatore()
{
    attributeList = new ArrayList<Attribute>();
    initialize();
}    

private void initialize()
{

    featureWords= new ArrayList<String>(); 

    featureMap = new TreeMap<>();

    classValues= new ArrayList<String>();
    classValues.add("politica");
    classValues.add("sport");
    classValues.add("musicatvcinema");
    classValues.add("infogeneriche");
    classValues.add("fattidelgiorno");
    classValues.add("statopersonale");
    classValues.add("checkin");
    classValues.add("conversazione");
}

Testing methods

public void testClassifier(final String INPUT_FILENAME) throws Exception
{
    getTrainingDataset(INPUT_FILENAME);

    //testingInstances will hold the feature vector of every input
    Instances testingInstances = createInstances("TESTING_INSTANCES");

    for(Instance currentInstance : inputDataset)
    {

        //extractFeature method returns the feature vector for the current input
        Instance currentFeatureVector = extractFeature(currentInstance);
        //Add currentFeatureVector to testingInstances
        currentFeatureVector.setDataset(testingInstances);
        testingInstances.add(currentFeatureVector);

    }


    try {
        //Classifier deserialization
        classifier = (Classifier) weka.core.SerializationHelper.read("NaiveBayes.model");

        //classifier testing code
        for(Instance testInstance : testingInstances)
        {

            double score = classifier.classifyInstance(testInstance);
            double[] vv= classifier.distributionForInstance(testInstance);
            for(int k=0;k<vv.length;k++){
                System.out.println("distribution "+vv[k]); //these are the class probabilities; I get 1.0 for one class and 0.0 for all the others
            }
            System.out.println(testingInstances.attribute("Class").value((int)score));
        }
    } catch (Exception ex) {
        System.out.println("Exception in testing the classifier."+ex);
    }
}

I want to create a text classifier for short messages. This code is based on this tutorial: http://preciselyconcise.com/apis_and_installations/training_a_weka_classifier_in_java.php . The problem is that the classifier predicts the wrong class for almost every message in testing.arff, because the class probabilities are not correct. training_set_prova_tent.arff has the same number of messages per class. The tutorial I'm following uses a featureWords.dat file and assigns 1.0 to a word if it is present in a message; instead, I want to build my own dictionary from the words in training_set_prova_tent.arff plus the words in the test set, and associate each word with its number of occurrences.

P.S. I know that the StringToWordVector filter does exactly this, but I haven't found any example that explains how to use it with two files, one for the training set and one for the test set. So it seemed easier to adapt the code I found.

Thank you very much


Solution

  • It seems you changed the code from the website you referenced at some crucial points, and not for the better. I'll try to outline what you're trying to do and the mistakes I've found.

    What you (probably) wanted to do in extractFeature is

    • Split each tweet into words (tokenize)
    • Count the number of occurrences of these words
    • Create a feature vector representing these word counts plus the class
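The first two of these steps (tokenize, then count) can be sketched in plain Java as below. This is a minimal illustration, not the tutorial's code: the class `TweetTokens` and the method `countWords` are hypothetical names, and the sketch assumes whitespace tokenization and `#` stripping, as in your extractFeature. The count map is created inside the method, so each tweet is counted on its own.

```java
import java.util.Map;
import java.util.StringTokenizer;
import java.util.TreeMap;

// Hypothetical helper: tokenizes one tweet and counts its words.
public class TweetTokens {

    // Returns word -> occurrence count for a single tweet.
    // The map is created inside the method, so counts never leak between tweets.
    public static Map<String, Integer> countWords(String tweet) {
        Map<String, Integer> counts = new TreeMap<>();
        StringTokenizer tokenizer = new StringTokenizer(tweet);
        while (tokenizer.hasMoreTokens()) {
            // Strip hashtags the same way the question's code does
            String word = tokenizer.nextToken().replace("#", "");
            if (word.isEmpty()) {
                continue; // a lone "#" leaves nothing to count
            }
            counts.merge(word, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(countWords("Bravo Garcia #Bravo"));
    }
}
```

For example, `countWords("Bravo Garcia #Bravo")` returns `{Bravo=2, Garcia=1}`.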

    What you've overlooked in that method is

    1. You never reset your featureMap. The line

      Map<Integer,Double> featureMap = new TreeMap<>();
      

      was originally at the beginning of extractFeature, but you moved it to initialize. That means you keep adding to the word counts but never reset them: for each new tweet, the counts also include the words of all previous tweets. I'm sure that is not what you wanted.

    2. You don't initialize featureWords with the words you want as features. Yes, you create an empty list, but then you fill it incrementally with each tweet. The original code initialized it once in the initialize method and never changed it after that. Filling it per tweet causes two problems:

      • With each new tweet, new features (words) get added, so your feature vector grows with each tweet. That alone wouldn't be a big problem (you use SparseInstance), but it means that
      • your class attribute ends up at a different index each time. The two lines below work in the original code because featureWords.size() is effectively a constant, but in your code the class label will be at index 5, then 8, then 12, and so on, whereas it must be at the same index for every instance.
      indices[i] = featureWords.size();
      values[i] = (double) classValues.indexOf(inputInstance.stringValue(1));
      
    3. The same problem shows up in the fact that you build a new attributeList for each new tweet instead of only once in initialize, which is bad for the reasons already explained.
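A fixed feature layout can be sketched in plain Java, using arrays rather than Weka classes so the indexing is easy to see. It assumes the vocabulary is fixed before any tweet is processed (as the tutorial's initialize did from featureWords.dat); the class `FixedLayout` and its methods are hypothetical. Every vector has length vocabulary size + 1, word counts start from zero for each tweet, and the class value always sits in the last slot, which is the layout the `indices[i] = featureWords.size()` line relies on.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;

// Hypothetical illustration: a vector layout that is identical for every tweet.
public class FixedLayout {
    private final Map<String, Integer> wordIndex = new HashMap<>();
    private final List<String> classValues;
    private final int classIndex; // position of the class value, never changes

    // The vocabulary and class labels are fixed once, up front.
    public FixedLayout(List<String> vocabulary, List<String> classValues) {
        for (int i = 0; i < vocabulary.size(); i++) {
            wordIndex.put(vocabulary.get(i), i);
        }
        this.classValues = classValues;
        this.classIndex = vocabulary.size();
    }

    public int classIndex() {
        return classIndex;
    }

    // Builds a fresh vector for one tweet: word counts, then the class label.
    public double[] vectorize(String tweet, String label) {
        double[] values = new double[classIndex + 1]; // all zeros for each tweet
        StringTokenizer tokenizer = new StringTokenizer(tweet);
        while (tokenizer.hasMoreTokens()) {
            Integer idx = wordIndex.get(tokenizer.nextToken().replace("#", ""));
            if (idx != null) { // words outside the vocabulary are ignored
                values[idx] += 1.0;
            }
        }
        values[classIndex] = classValues.indexOf(label);
        return values;
    }

    public static void main(String[] args) {
        FixedLayout layout = new FixedLayout(
                Arrays.asList("Renzi", "Berlusconi", "Allegri", "Garcia"),
                Arrays.asList("politica", "sport"));
        System.out.println(Arrays.toString(layout.vectorize("Renzi Berlusconi Renzi", "politica")));
        System.out.println(Arrays.toString(layout.vectorize("Bravo Garcia", "sport")));
    }
}
```

With this four-word vocabulary, both tweets in `main` produce vectors of length 5 with the class label at index 4; `Bravo` is ignored because it is not in the vocabulary. In a Weka version, the same idea means building attributeList and the Instances header once (for example in initialize) and only filling in values per tweet.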

    There may be more issues, but as it stands your code is hard to fix. What you want is much closer to the tutorial source code you modified than to your own version.

    Also, you should look into StringToWordVector, because it seems to do exactly what you want:

    Converts String attributes into a set of attributes representing word occurrence (depending on the tokenizer) information from the text contained in the strings. The set of words (attributes) is determined by the first batch filtered (typically training data).
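In Weka, the usual pattern for two files is to call `setInputFormat(train)` on the StringToWordVector instance and then run both sets through that same instance with `Filter.useFilter(train, filter)` and `Filter.useFilter(test, filter)`, so the test data gets exactly the same word attributes; `setOutputWordCounts(true)` gives counts instead of 0/1 presence. Wrapping the filter and the classifier in a `FilteredClassifier` is another option. The plain-Java sketch below only illustrates the "first batch" behaviour the quote describes; it is not Weka's implementation, and `BatchVocabulary` is a hypothetical name. The vocabulary is fixed by the first call, and words that appear only in later (test) batches are dropped rather than added as new features.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringTokenizer;
import java.util.TreeSet;

// Hypothetical illustration of "the set of words is determined by the first batch".
public class BatchVocabulary {
    private List<String> vocabulary; // null until the first batch is seen

    // The first call fixes the vocabulary; later calls reuse it unchanged.
    public List<double[]> transform(List<String> batch) {
        if (vocabulary == null) {
            TreeSet<String> words = new TreeSet<>();
            for (String text : batch) {
                words.addAll(tokenize(text));
            }
            vocabulary = new ArrayList<>(words);
        }
        List<double[]> rows = new ArrayList<>();
        for (String text : batch) {
            double[] counts = new double[vocabulary.size()];
            for (String word : tokenize(text)) {
                int idx = vocabulary.indexOf(word);
                if (idx >= 0) { // words unseen in the first batch are dropped
                    counts[idx] += 1.0;
                }
            }
            rows.add(counts);
        }
        return rows;
    }

    public List<String> vocabulary() {
        return vocabulary;
    }

    // Whitespace tokenization, like StringTokenizer's default
    private static List<String> tokenize(String text) {
        List<String> tokens = new ArrayList<>();
        StringTokenizer tokenizer = new StringTokenizer(text);
        while (tokenizer.hasMoreTokens()) {
            tokens.add(tokenizer.nextToken());
        }
        return tokens;
    }

    public static void main(String[] args) {
        BatchVocabulary vectorizer = new BatchVocabulary();
        vectorizer.transform(Arrays.asList("Allegri insulta la terna arbitrale", "Bravo Garcia"));
        List<double[]> test = vectorizer.transform(Arrays.asList("Bravo Allegri Bravo Renzi"));
        System.out.println(vectorizer.vocabulary());
        System.out.println(Arrays.toString(test.get(0)));
    }
}
```

In `main`, the training batch fixes a seven-word vocabulary, so the test tweet contributes counts for `Bravo` and `Allegri`, while `Renzi` is dropped.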