java weka tf-idf tfidfvectorizer

Calculate TF-IDF in WEKA API for single document to predict classification


For some reason I am using the WEKA API...

I have generated TF-IDF scores for a set of documents:

StringToWordVector filter = new StringToWordVector();
filter.setIDFTransform(true);
filter.setStopwordsHandler(new StopWordsHandlerEN());//just a simple handler for stop words I created
filter.setLowerCaseTokens(true);
filter.setStemmer(new MyStemmer());//a stemmer I created
filter.setWordsToKeep(words2keep);
filter.setInputFormat(data);//call this last, after all options have been set
Instances result = Filter.useFilter(data, filter);

Then I split them into train and test subsets and did the training, testing and all that...

Once I had a trained, ready-to-go classification model, I wanted to create a plain API that would classify any incoming document. But the TF-IDF scores for a new document have to be calculated from the vocabulary and document frequencies of the original set of documents, right? In other words, if I am not mistaken, I need a counterpart of scikit-learn's TfidfVectorizer that I can load and reuse.

I cannot find anything like it in WEKA... Is there one?


Solution

  • The StringToWordVector filter uses the weka.core.DictionaryBuilder class under the hood for the TF/IDF computation.

    To convert a single document, create a weka.core.Instance object containing its text and pass it to the builder's vectorizeInstance(Instance) method.

    Edit 1:

    Below is an example based on your code (but with Weka classes), which shows how to use either the filter or the DictionaryBuilder for the TF/IDF transformation. Both are also serialized, deserialized and re-used to demonstrate that these classes are serializable:

    import weka.core.DictionaryBuilder;
    import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.SerializationHelper;
    import weka.core.converters.ConverterUtils;
    import weka.core.stemmers.LovinsStemmer;
    import weka.core.stopwords.Rainbow;
    import weka.filters.Filter;
    import weka.filters.unsupervised.attribute.StringToWordVector;
    
    public class TFIDF {
    
      // just exposes the internal DictionaryBuilder member
      public static class StringToWordVectorExposed
        extends StringToWordVector {
    
        public DictionaryBuilder getDictionary() {
          return m_dictionaryBuilder;
        }
      }
    
      public static void main(String[] args) throws Exception {
        // load data
        Instances train = ConverterUtils.DataSource.read("/some/where/train.arff");
        train.setClassIndex(train.numAttributes() - 1);
        Instances test = ConverterUtils.DataSource.read("/some/where/test.arff");
        test.setClassIndex(test.numAttributes() - 1);
        // init filter: set all options first, then call setInputFormat last
        StringToWordVectorExposed filter = new StringToWordVectorExposed();
        int words2keep = 100;
        filter.setIDFTransform(true);
        filter.setStopwordsHandler(new Rainbow());
        filter.setLowerCaseTokens(true);
        filter.setStemmer(new LovinsStemmer());
        filter.setWordsToKeep(words2keep);
        filter.setInputFormat(train);
        Instances trainFiltered = Filter.useFilter(train, filter);
        DictionaryBuilder builder = filter.getDictionary();
        // apply filter/dictionary
        Instances testFiltered = Filter.useFilter(test, filter);
        System.out.println(testFiltered.instance(0));
        Instance tfidf = builder.vectorizeInstance(test.instance(0));
        System.out.println(tfidf);
        // serialize
        SerializationHelper.write("/some/where/filter.ser", filter);
        SerializationHelper.write("/some/where/dictionary.ser", filter.getDictionary());
        // deserialize
        StringToWordVectorExposed filter2 = (StringToWordVectorExposed) SerializationHelper.read("/some/where/filter.ser");
        DictionaryBuilder builder2 = (DictionaryBuilder) SerializationHelper.read("/some/where/dictionary.ser");
        // re-apply
        testFiltered = Filter.useFilter(test, filter2);
        System.out.println(testFiltered.instance(0));
        tfidf = builder2.vectorizeInstance(test.instance(0));
        System.out.println(tfidf);
      }
    }
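
To make it clearer why the fitted filter or dictionary has to be re-used for new documents, here is a minimal dependency-free sketch of the idea. It is not Weka's implementation: the class `TfIdfSketch`, its `Model`, the whitespace/punctuation tokenizer and the plain `count * log(N / df)` weighting are all assumptions chosen for illustration, and Weka's filter applies further options (stop words, stemming, word pruning, normalization) that are ignored here. The point is only that the vocabulary and document frequencies are learned once from the training documents and then applied unchanged to an incoming document.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;

// Hypothetical illustration class; not part of Weka.
public class TfIdfSketch {

  /** Vocabulary and document frequencies learned from the training documents. */
  public static class Model {
    public final List<String> vocabulary = new ArrayList<>();
    public final Map<String, Integer> docFreq = new LinkedHashMap<>();
    public int numDocs;

    /** Weighs a new document using only the training-time statistics. */
    public double[] transform(String doc) {
      double[] vector = new double[vocabulary.size()];
      for (String token : tokenize(doc)) {
        int index = vocabulary.indexOf(token);
        if (index >= 0) {
          vector[index] += 1.0; // raw term count; words unseen in training are ignored
        }
      }
      for (int i = 0; i < vector.length; i++) {
        int df = docFreq.get(vocabulary.get(i));
        vector[i] *= Math.log((double) numDocs / df); // assumed IDF form: log(N / df)
      }
      return vector;
    }
  }

  /** Lower-cases the text and splits it on non-word characters. */
  static List<String> tokenize(String doc) {
    List<String> tokens = new ArrayList<>();
    for (String token : doc.toLowerCase().split("\\W+")) {
      if (!token.isEmpty()) {
        tokens.add(token);
      }
    }
    return tokens;
  }

  /** Builds the vocabulary and counts in how many training documents each word occurs. */
  public static Model fit(List<String> docs) {
    Model model = new Model();
    model.numDocs = docs.size();
    for (String doc : docs) {
      // a set, so that each word is counted at most once per document
      for (String token : new LinkedHashSet<>(tokenize(doc))) {
        model.docFreq.merge(token, 1, Integer::sum);
      }
    }
    model.vocabulary.addAll(model.docFreq.keySet());
    return model;
  }

  public static void main(String[] args) {
    Model model = fit(Arrays.asList("the cat sat", "the dog barked", "the cat purred"));
    // the incoming document is weighted with the training statistics only
    double[] vector = model.transform("a cat and a dog");
    System.out.println(model.vocabulary);
    System.out.println(Arrays.toString(vector));
  }
}
```

In the Weka example above, the serialized filter (or DictionaryBuilder) plays the role of `Model`: it is fitted once on the training set and then only applied to each incoming document.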