java lstm dl4j nd4j

What is the output of the dl4j lstm neural network?


I am studying a text generation example (https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java). As I understand it, the output of the LSTM network is a probability distribution: a double array in which each value is the probability of the character corresponding to that index. What I cannot understand is the following code, which gets the character index from the distribution:

/** Given a probability distribution over discrete classes, sample from the distribution
 * and return the generated class index.
 * @param distribution Probability distribution over classes. Must sum to 1.0
 */
static int sampleFromDistribution(double[] distribution, Random rng){
    double d = 0.0;
    double sum = 0.0;
    for( int t=0; t<10; t++ ) {
        d = rng.nextDouble();
        sum = 0.0;
        for( int i=0; i<distribution.length; i++ ){
            sum += distribution[i];
            if( d <= sum ) return i;
        }
        //If we haven't found the right index yet, maybe the sum is slightly
        //lower than 1 due to rounding error, so try again.
    }
    //Should be extremely unlikely to happen if distribution is a valid probability distribution
    throw new IllegalArgumentException("Distribution is invalid? d="+d+", sum="+sum);
}

It seems that we are getting a random value. Why don't we just choose the index with the highest value? What should I do if I want to select not one, but the two or three most likely next characters?


Solution

  • This function samples from the distribution, instead of simply returning the most probable character class.

    That means you aren't necessarily getting the most likely character. Instead, you get a random character, and each character is chosen with the probability that the given distribution assigns to it.

    This works by first drawing a random value from a uniform distribution over [0, 1) (rng.nextDouble()) and then walking through the distribution, adding up the probabilities, until the running sum reaches that value. The index at which that happens is the sampled class.

    You can imagine it to be something like this (if you had only a to f in your alphabet):

     [   a    | b |   c   |   d    | e |     f     ] 
    0.0          0.3              0.5             1.0
    

    If the random value that is drawn is just over 0.5, it produces an e. If it is just under 0.5, it produces a d.

    Each letter occupies a proportional amount of space on this line between 0 and 1 according to the weight it has in the distribution.
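
To make the difference concrete, and to cover the "two or three most likely characters" part of the question, here is a minimal sketch. It is not code from the DL4J example: the class SamplingSketch, the helpers argMax and topK, and the six probabilities are all hypothetical, with the numbers chosen only so that the boundaries land on the 0.3 and 0.5 marks in the picture above. The sample method follows the same cumulative-sum idea as sampleFromDistribution, minus the retry loop.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

class SamplingSketch {

    /** Greedy choice: index of the largest probability. */
    static int argMax(double[] distribution) {
        int best = 0;
        for (int i = 1; i < distribution.length; i++) {
            if (distribution[i] > distribution[best]) best = i;
        }
        return best;
    }

    /** Indices of the k largest probabilities, most probable first. */
    static int[] topK(double[] distribution, int k) {
        Integer[] idx = new Integer[distribution.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        // Sort indices by descending probability (negated key = descending order).
        Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> -distribution[i]));
        int n = Math.max(0, Math.min(k, idx.length));
        int[] out = new int[n];
        for (int i = 0; i < n; i++) out[i] = idx[i];
        return out;
    }

    /** Walk the running sum until it reaches the uniform draw d. */
    static int sample(double[] distribution, Random rng) {
        double d = rng.nextDouble();
        double sum = 0.0;
        for (int i = 0; i < distribution.length; i++) {
            sum += distribution[i];
            if (d <= sum) return i;
        }
        // Fallback for rounding error; assumes the input sums to roughly 1.0.
        return distribution.length - 1;
    }

    public static void main(String[] args) {
        // Hypothetical probabilities for a..f: a+b = 0.3, a+b+c+d = 0.5, total = 1.0.
        double[] p = {0.18, 0.12, 0.09, 0.11, 0.05, 0.45};
        String alphabet = "abcdef";
        Random rng = new Random(12345);

        System.out.println("most likely: " + alphabet.charAt(argMax(p)));

        StringBuilder top = new StringBuilder();
        for (int i : topK(p, 3)) top.append(alphabet.charAt(i));
        System.out.println("top 3: " + top);

        StringBuilder sampled = new StringBuilder();
        for (int i = 0; i < 20; i++) sampled.append(alphabet.charAt(sample(p, rng)));
        System.out.println("20 samples: " + sampled);
    }
}
```

With these assumed numbers, argMax always returns f, whereas sample returns f only about 45% of the time and the other letters in proportion to their weights. Always taking the most likely character makes generation deterministic, and in practice it often yields repetitive text, which is a common reason for sampling instead. If you want several candidates, something like topK gives you the indices of the most probable characters, and you can then decide how to choose among them.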