Tags: machine-learning, nlp, lstm, language-model, penn-treebank

How to train a language model?


  1. I'm trying to train an LSTM language model on the Penn Treebank (PTB) corpus.

    I was thinking that I should simply train on every bigram in the corpus so that the model could predict the next word given the previous word, but then it wouldn't be able to predict the next word based on multiple preceding words.

    So what exactly does it mean to train a language model?

  2. In my current implementation, I have a batch size of 20 and a vocabulary size of 10000, so I have 20 resulting matrices of 10k entries (parameters?) and the loss is calculated by comparing them to 20 ground-truth matrices of 10k entries, where only the index of the actual next word is 1 and all other entries are zero. Is this a right implementation? I'm getting perplexity of around 2 that hardly changes over iterations, which is definitely not in a right range of what it usually is, say around 100.


Solution

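  • So what exactly does it mean to train a language model?

    I think you don't need to train on every bigram in the corpus. Just use a sequence-to-sequence style model: feed in a sequence of words and train the model to predict the same sequence shifted by one position, so that each prediction is conditioned on all the preceding words through the LSTM state. When you predict the next word given the previous words, you just choose the one with the highest probability.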
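A minimal sketch of how training examples for next-word prediction could be built, assuming the corpus has already been mapped to integer word ids. The helper name `make_training_pairs` and the toy ids are hypothetical and are not taken from any of the projects mentioned here; the point is only that the target is the input shifted by one word.

```python
def make_training_pairs(token_ids, num_steps):
    """Split a token stream into (input, target) windows of length num_steps.

    The target window is the input window shifted one position to the right,
    so position t of the target is the word that follows position t of the input.
    """
    pairs = []
    # Stop early enough that every input position has a following word.
    for start in range(0, len(token_ids) - num_steps, num_steps):
        inputs = token_ids[start:start + num_steps]
        targets = token_ids[start + 1:start + num_steps + 1]
        pairs.append((inputs, targets))
    return pairs


# Toy corpus already mapped to integer ids (hypothetical values).
corpus = [4, 8, 15, 16, 23, 42, 7]
for inputs, targets in make_training_pairs(corpus, num_steps=3):
    print(inputs, "->", targets)
# [4, 8, 15] -> [8, 15, 16]
# [16, 23, 42] -> [23, 42, 7]
```

Because the LSTM carries its state across the steps of a window, the prediction at step t can depend on every earlier word in that window, which a bigram-only setup cannot do.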

    so I have 20 resulting matrices of 10k entries (parameters?)

    Yes, at each step of decoding. (They are the model's output scores over the vocabulary, not parameters.)

    Is this a right implementation? I'm getting perplexity of around 2 that hardly changes over iterations, which is definitely not in a right range of what it usually is, say around 100.

    You can first read some open-source code as a reference, for instance word-rnn-tensorflow and char-rnn-tensorflow. The value -log(1/10000), which is around 9.2 per word, is the cross-entropy loss (natural log) of a model that has not been trained at all and selects words uniformly at random. As the model is trained the loss decreases, so a value of 2 is reasonable if it is the per-word loss. Note that perplexity is the exponential of the per-word loss, so a loss of 2 corresponds to a perplexity of about 7.4. I think the 100 in your statement may be the loss per sentence.
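The arithmetic behind these numbers can be checked with a few lines of plain Python. This is only an illustrative sketch: the function names are made up here, and it assumes the loss is a cross-entropy measured with the natural logarithm.

```python
import math


def uniform_cross_entropy(vocab_size):
    """Per-word cross-entropy (in nats) of a model that assigns
    probability 1/vocab_size to every word."""
    return -math.log(1.0 / vocab_size)


def perplexity(per_word_loss):
    """Perplexity is the exponential of the per-word cross-entropy."""
    return math.exp(per_word_loss)


loss = uniform_cross_entropy(10000)
print(round(loss, 2))             # 9.21
print(round(perplexity(loss)))    # 10000
print(round(perplexity(2.0), 2))  # 7.39
```

So a reported value near 9 for an untrained model is a loss, not a perplexity, and it is worth checking which of the two quantities a training script actually prints.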

    For example, if tf.contrib.seq2seq.sequence_loss is used to compute the loss, the result will be less than 10 if you leave both average_across_timesteps and average_across_batch at their default of True, but if you set average_across_timesteps to False and the average sequence length is about 10, it will be about 100.
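The effect of averaging versus summing over time steps can be imitated in plain Python. This sketch is not the library's implementation: `reduce_loss` is a hypothetical helper that assumes equal-length sequences with no padding weights, and it always averages across the batch.

```python
def reduce_loss(per_token_loss, average_across_timesteps=True):
    """Reduce a batch of per-word losses to a single number.

    per_token_loss: list of sequences, each a list of per-word losses.
    The result is always averaged across the batch; within a sequence the
    losses are either averaged or summed over time steps.
    """
    per_sequence = []
    for seq in per_token_loss:
        total = sum(seq)
        if average_across_timesteps:
            per_sequence.append(total / len(seq))
        else:
            per_sequence.append(total)
    return sum(per_sequence) / len(per_sequence)


# Two sequences of 10 words, each word at the untrained loss of about 9.2.
batch = [[9.2] * 10 for _ in range(2)]
print(reduce_loss(batch, average_across_timesteps=True))   # about 9.2
print(reduce_loss(batch, average_across_timesteps=False))  # about 92
```

With sequences of about 10 words, the summed variant is roughly ten times the per-word value, which is how a number near 100 can appear even though the per-word loss stays below 10.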