Search code examples
Tags: java, topic-modeling, mallet

Different topic distributions for the same data with mallet topic modeling


I am using MALLET for topic modeling and have trained a model. Right after training, I print the topic distribution of one of the training documents and save it. Then I feed the same document in as test data, passing it through the same pipes, but I get a completely different topic distribution. The highest-ranked topic after training, which has a probability of around 0.54, has a probability of 0.000 at test time. Here is my code for training and testing:

 /**
     * Trains a MALLET LDA topic model on {@code E:\Alltogether.txt} and prints diagnostics for
     * one training document.
     *
     * <p>Each input line is split by the CSV regex into name, label and data fields. The data
     * field is run through the pipe chain below. After training, the method prints two things:
     * the per-token topic assignments and topic distribution of the document at index 66, and
     * the top 10 words of every topic.
     *
     * @return a list whose element 0 is the trained {@link ParallelTopicModel} and element 1 is
     *     the training {@link InstanceList}. The InstanceList carries the training pipe
     *     ({@code getPipe()}) and its data Alphabet. Callers that run inference on new text must
     *     reuse that pipe rather than construct a new, identical-looking one; otherwise word
     *     ids will not match the ones the model was trained on.
     * @throws IOException if the corpus file cannot be opened or read
     */
    public static ArrayList<Object> trainModel() throws IOException {

        String fileName = "E:\\Alltogether.txt";
        String stopwords = "E:\\stopwords-en.txt";

        // Pipes: lowercase, tokenize, remove stopwords, map to features.
        // The SerialPipes built here owns the data Alphabet (word <-> feature id) that the model
        // is trained against. It is handed back to callers through the training InstanceList.
        ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
        pipeList.add(new CharSequenceLowercase());
        // Token = a letter, then letters/punctuation, ending in a letter (so at least 3 chars).
        pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
        pipeList.add(new TokenSequenceRemoveStopwords(new File(stopwords), "UTF-8", false, false, false));
        pipeList.add(new TokenSequenceRemoveNonAlpha(true));
        pipeList.add(new TokenSequence2FeatureSequence());
        InstanceList instances = new InstanceList(new SerialPipes(pipeList));

        // addThruPipe consumes the whole iterator, so the reader can be closed right afterwards.
        // try-with-resources guarantees that, even when piping throws.
        try (Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8")) {
            instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
                    3, 2, 1)); // data, label, name fields
        }

        // ParallelTopicModel(numberOfTopics, alphaSum, beta)
        int numTopics = 75;
        ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01);

        model.setOptimizeInterval(20);
        model.addInstances(instances);
        model.setNumThreads(2);
        model.setNumIterations(2000);
        model.estimate();

        ArrayList<Object> results = new ArrayList<>();
        results.add(model);      // index 0: trained model
        results.add(instances);  // index 1: training data, including the training pipe/Alphabet

        Alphabet dataAlphabet = instances.getDataAlphabet();

        // Index (0-based) of the training document whose Gibbs state is inspected below.
        int docIndex = 66;
        FeatureSequence tokens = (FeatureSequence) model.getData().get(docIndex).instance.getData();
        LabelSequence topics = model.getData().get(docIndex).topicSequence;

        // Print every token of that document as "word-topic".
        Formatter out = new Formatter(new StringBuilder(), Locale.US);
        for (int position = 0; position < tokens.getLength(); position++) {
            out.format("%s-%d ", dataAlphabet.lookupObject(tokens.getIndexAtPosition(position)), topics.getIndexAtPosition(position));
        }
        System.out.println(out);

        // Topic distribution of the same document, given the current Gibbs state.
        double[] topicDistribution = model.getTopicProbabilities(docIndex);

        ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();

        // One line per topic: topic id, this document's probability, then the top 10 words with counts.
        for (int topic = 0; topic < numTopics; topic++) {
            Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();

            out = new Formatter(new StringBuilder(), Locale.US);
            out.format("%d\t%.3f\t", topic, topicDistribution[topic]);
            int rank = 0;
            while (iterator.hasNext() && rank < 10) {
                IDSorter idCountPair = iterator.next();
                out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
                rank++;
            }
            System.out.println(out);
        }

        return results;
    }

And here is the testing part:

/**
 * Infers the topic distribution of the first document in {@code testDir} using a model
 * produced by {@code trainModel()}, and prints it next to each topic's top 10 words.
 *
 * <p>Test documents are piped through the <em>training</em> pipe, taken from the training
 * InstanceList. Building a fresh pipe chain that merely looks identical is not equivalent.
 * A newly constructed TokenSequence2FeatureSequence starts with an empty Alphabet, so the same
 * word would receive a different feature id than it had during training. The inferencer would
 * then score unrelated words, which makes even a training document come back with a
 * completely different distribution.
 *
 * @param results element 0 must be the trained ParallelTopicModel; element 1 must be the
 *     training InstanceList (as returned by {@code trainModel()})
 * @param testDir path of a CSV-style file in the same line format as the training corpus
 * @throws java.io.UncheckedIOException if the test file cannot be opened or read
 * @throws IllegalArgumentException if the test file yields no documents
 */
private static void testModel(ArrayList<Object> results, String testDir) {

    ParallelTopicModel model = (ParallelTopicModel) results.get(0);
    InstanceList allTrainInstances = (InstanceList) results.get(1);

    // Reuse the training pipe (and with it the training data Alphabet) so that the
    // word -> feature id mapping matches the one the model was trained on.
    InstanceList instances = new InstanceList(allTrainInstances.getPipe());

    try (Reader fileReader = new InputStreamReader(new FileInputStream(new File(testDir)), "UTF-8")) {
        instances.addThruPipe(new CsvIterator(fileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
                3, 2, 1)); // data, label, name fields
    } catch (IOException e) {
        // Fail fast with the cause. Continuing with a null reader only produced an NPE later.
        throw new java.io.UncheckedIOException("Could not read test file " + testDir, e);
    }
    if (instances.isEmpty()) {
        throw new IllegalArgumentException("No documents found in test file " + testDir);
    }

    TopicInferencer inferencer = model.getInferencer();
    inferencer.setRandomSeed(1); // deterministic sampling across runs

    // getSampledDistribution(instance, numIterations = 10, thinning = 1, burnIn = 5)
    double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);
    // Print the values; println(double[]) would print only the array's identity.
    System.out.println(java.util.Arrays.toString(testProbabilities));
    // NOTE(review): getMaximum is defined elsewhere in this file; presumably it returns the
    // index of the largest entry. Confirm.
    int index = getMaximum(testProbabilities);
    System.out.println("Most probable topic: " + index);

    ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
    Alphabet dataAlphabet = allTrainInstances.getDataAlphabet();

    // Take the topic count from the model instead of hard-coding it, so this stays in sync with training.
    int numTopics = model.getNumTopics();
    for (int topic = 0; topic < numTopics; topic++) {
        Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();

        Formatter out = new Formatter(new StringBuilder(), Locale.US);
        out.format("%d\t%.3f\t", topic, testProbabilities[topic]);
        int rank = 0;
        while (iterator.hasNext() && rank < 10) {
            IDSorter idCountPair = iterator.next();
            out.format("%s (%.0f) ", dataAlphabet.lookupObject(idCountPair.getID()), idCountPair.getWeight());
            rank++;
        }
        System.out.println(out);
    }

}

In the line

    double[] testProbabilities = inferencer.getSampledDistribution(instances.get(0), 10, 1, 5);

I can clearly see that the probabilities are different. I have also tried other files, but the same topic always comes out as the highest ranked. Any help is appreciated.


Solution

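  • I am answering my own question in case somebody else faces the same problem. The MALLET documentation says you should use the same pipes for training and testing. I realized that constructing ("new"-ing) pipes identical to the training ones does NOT mean using the same pipes. Each newly created TokenSequence2FeatureSequence starts with an empty Alphabet, so the test words get different feature ids than the ones the model was trained on. You should save the pipe when you train your model and reload it when testing. Within the same program, you can instead build the test InstanceList from the training pipe: `new InstanceList(trainingInstances.getPipe())`. I followed the sample code for this approach, and it works now.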