Tags: java, java-8, deep-learning, deeplearning4j

RL4J A3C deep learning throwing "Output from network is not a probability distribution"


So right now I am taking the painful dive into deep learning using Deeplearning4j (DL4J), specifically RL4J and reinforcement learning. I have been relatively unsuccessful in teaching my computer how to play Snake, but I persevere.

Anyway, I have been running into a problem that I can't solve. I set my program to run while I sleep or am at work (yes, I work in an essential industry), and when I check back, it has thrown this error on all running threads and the program has completely stopped. Mind you, this usually happens about an hour into training:

    Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[         ?,         ?,         ?]]
        at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
        at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
        at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
        at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
        at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)

Here is how I am setting up my network:

    private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
            new A3CDiscrete.A3CConfiguration(
                    (new java.util.Random()).nextInt(), // Random seed
                    220,                                // Max step by epoch
                    500000,                             // Max step
                    6,                                  // Number of threads
                    50,                                 // t_max
                    75,                                 // Num step noop warmup
                    0.1,                                // Reward scaling
                    0.987,                              // Gamma
                    1.0                                 // TD-error clipping
            );


    private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C =
            ActorCriticFactorySeparateStdDense.Configuration.builder()
                    .updater(new Adam(.005))
                    .l2(.01)
                    .numHiddenNodes(32)
                    .numLayer(3)
                    .build();

Also, the input to my network is the entire 16x16 grid of my snake game, flattened into a single double array.
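A minimal sketch of that flattening in plain Java, assuming a hypothetical `int[16][16]` board whose cell codes (0 = empty, 1 = snake, 2 = apple) are made up for illustration. Row-major order puts cell (row, col) at index `row * 16 + col`; whichever order is used, it only has to be the same on every step.

```java
// Hypothetical encoding: the board layout and cell codes are illustrative assumptions.
class GridEncoding {
    static final int SIZE = 16; // the question uses a 16x16 grid

    // Flattens a SIZE x SIZE board into one double[] in row-major order,
    // so cell (row, col) ends up at index row * SIZE + col.
    static double[] flatten(int[][] board) {
        double[] out = new double[SIZE * SIZE];
        for (int row = 0; row < SIZE; row++) {
            for (int col = 0; col < SIZE; col++) {
                out[row * SIZE + col] = board[row][col];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] board = new int[SIZE][SIZE];
        board[3][5] = 1;  // hypothetical snake cell
        board[10][2] = 2; // hypothetical apple cell
        double[] input = flatten(board);
        System.out.println(input.length);          // 256
        System.out.println(input[3 * SIZE + 5]);   // 1.0
        System.out.println(input[10 * SIZE + 2]);  // 2.0
    }
}
```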

In case it has something to do with my reward function, here it is:

    if (!snake.inGame()) {
        return -5.3; // snake dies
    }
    if (snake.gotApple()) {
        return 5.0 + .37 * (snake.getLength()); // snake gets apple
    }
    return 0; // survives
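To make the scale of those rewards concrete, here is the same function wrapped in a runnable sketch. `Snake` is a hypothetical interface standing in for the question's game class (it declares only the three methods the reward code calls), and the scaled column assumes RL4J multiplies each step reward by the 0.1 "reward scaling" value from the A3C configuration in the question.

```java
// Hypothetical interface mirroring the three methods the reward code calls.
interface Snake {
    boolean inGame();
    boolean gotApple();
    int getLength();
}

class SnakeReward {
    // The question's reward function, unchanged, wrapped in a method.
    static double reward(Snake snake) {
        if (!snake.inGame()) {
            return -5.3; // snake dies
        }
        if (snake.gotApple()) {
            return 5.0 + .37 * snake.getLength(); // snake gets apple
        }
        return 0; // survives
    }

    // Minimal stand-in game state for trying the function out (hypothetical helper).
    static Snake state(final boolean inGame, final boolean gotApple, final int length) {
        return new Snake() {
            public boolean inGame() { return inGame; }
            public boolean gotApple() { return gotApple; }
            public int getLength() { return length; }
        };
    }

    public static void main(String[] args) {
        double rewardScaling = 0.1; // assumed to be applied as a multiplier, per the lead-in
        int[] lengths = {3, 50, 200};
        for (int length : lengths) {
            double r = reward(state(true, true, length));
            System.out.printf("length %d: raw %.2f, scaled %.3f%n", length, r, r * rewardScaling);
        }
        System.out.println(reward(state(false, false, 10))); // -5.3
    }
}
```

The apple reward is unbounded: it keeps growing with the snake's length, so the returns the critic has to predict tend to grow as the agent gets better.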

My question is: how do I stop this error from occurring? I truly have no idea what is happening, and it has been making building my network rather difficult. Yes, I have already searched the web for answers; all that comes up is a couple of GitHub tickets from 2018.

In case it's of interest, and so you don't have to go digging, here is the function from ACPolicy that throws the error:

    public Integer nextAction(INDArray input) {
        INDArray output = actorCritic.outputAll(input)[1];
        if (rnd == null) {
            return Learning.getMaxAction(output);
        }
        float rVal = rnd.nextFloat();
        for (int i = 0; i < output.length(); i++) {
            //System.out.println(i + " " + rVal + " " + output.getFloat(i));
            if (rVal < output.getFloat(i)) {
                return i;
            } else
                rVal -= output.getFloat(i);
        }

        throw new RuntimeException("Output from network is not a probability distribution: " + output);
    }
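That loop shows why NaNs end in this particular exception: in Java every comparison with NaN is false, so `rVal < output.getFloat(i)` never succeeds, `rVal -= NaN` turns `rVal` itself into NaN, and the loop falls through to the `throw`. The sketch below is a plain-Java replica of the loop (a `float[]` stands in for the ND4J `INDArray` so it runs without DL4J, and `-1` stands in for the exception):

```java
class SamplingDemo {
    // Mirrors the sampling loop in ACPolicy.nextAction, but over a plain float[].
    // Returns -1 where the original code throws.
    static int sample(float[] probs, float rVal) {
        for (int i = 0; i < probs.length; i++) {
            if (rVal < probs[i]) {
                return i;
            }
            rVal -= probs[i];
        }
        return -1; // ACPolicy throws "Output from network is not a probability distribution" here
    }

    public static void main(String[] args) {
        float[] valid = {0.2f, 0.5f, 0.3f};
        float[] nans = {Float.NaN, Float.NaN, Float.NaN};
        System.out.println(sample(valid, 0.5f)); // 1
        System.out.println(sample(nans, 0.5f));  // -1: 0.5f < NaN is false, then rVal becomes NaN
    }
}
```

Once the network's output is NaN, no random value can be sampled successfully. Because A3C workers copy their parameters from a shared global network, that would be consistent with every thread failing at roughly the same time.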

Any help you can offer is greatly appreciated.


Solution

  • What you are seeing is your network running into NaNs; that is what the question marks in the exception mean. There are many reasons why that may happen. Since you say it runs for quite a while before failing, it may be that you get underflows or overflows at some point. Some regularization or some gradient clipping may help.

    However, RL4J itself is being reworked as of beta6 and should be in a much better state by the next release.

    If you want to try the current state, there are snapshots you can use, and there is also a working A3C example at https://github.com/RobAltena/cartpole/blob/master/src/main/java/A3CCartpole.java

    For more thorough help, you should probably take a look at the DL4J community forum at community.konduit.ai. It is better suited to the back-and-forth that is likely needed to help you build a successful AI for your snake game.
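To illustrate the overflow point and the gradient-clipping suggestion in plain Java (this is not RL4J or ND4J code, and every method name here is a hypothetical helper): a naive softmax overflows `Math.exp` to `Infinity` for large logits, and `Infinity / Infinity` is NaN. Subtracting the maximum logit first avoids that, a finite-and-sums-to-one check catches bad outputs before sampling, and rescaling a gradient vector to a maximum L2 norm is one common form of gradient clipping. It shows the mechanism only and does not describe how ND4J computes softmax internally.

```java
import java.util.Arrays;

class NumericGuards {
    // Naive softmax: Math.exp overflows to Infinity for large logits,
    // and Infinity / Infinity evaluates to NaN.
    static double[] naiveSoftmax(double[] logits) {
        double sum = 0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i]);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // Stable softmax: subtracting the max logit keeps every exponent <= 0.
    static double[] stableSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // True if every entry is finite and non-negative and the entries sum to 1 within tol.
    static boolean isProbabilityDistribution(double[] p, double tol) {
        double sum = 0;
        for (double v : p) {
            if (Double.isNaN(v) || Double.isInfinite(v) || v < 0) {
                return false;
            }
            sum += v;
        }
        return Math.abs(sum - 1.0) <= tol;
    }

    // Clipping by norm: rescales g in place so its L2 norm is at most maxNorm.
    static void clipByNorm(double[] g, double maxNorm) {
        double sq = 0;
        for (double v : g) {
            sq += v * v;
        }
        double norm = Math.sqrt(sq);
        if (norm > maxNorm) {
            double scale = maxNorm / norm;
            for (int i = 0; i < g.length; i++) {
                g[i] *= scale;
            }
        }
    }

    public static void main(String[] args) {
        double[] logits = {1000.0, 999.0, 998.0};
        System.out.println(Arrays.toString(naiveSoftmax(logits)));                   // [NaN, NaN, NaN]
        System.out.println(isProbabilityDistribution(naiveSoftmax(logits), 1e-6));  // false
        System.out.println(isProbabilityDistribution(stableSoftmax(logits), 1e-6)); // true

        double[] grad = {30.0, 40.0}; // L2 norm 50
        clipByNorm(grad, 5.0);        // rescaled to norm 5
        System.out.println(Arrays.toString(grad));
    }
}
```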