I am trying to write my first neural network to play the game Connect Four. I'm using Java and deeplearning4j. I tried to implement a genetic algorithm, but when I train the network for a while, the outputs of the network jump to NaN, and I am unable to tell where I messed up so badly for this to happen. I will post all 3 classes below, where Game is the game logic and rules, VGFrame the UI, and Main all the NN stuff.
I have a pool of 35 neural networks, and each iteration I let the best 5 live and breed, and randomize the newly created ones a little. To evaluate the networks I let them battle each other, giving points to the winner and points to the loser for losing later. Since I penalize putting a stone into a column that's already full, I expected the neural networks at least to be able to play the game by the rules after a while, but they can't do this. I googled the NaN problem and it seems to be an exploding gradient problem, but from my understanding this shouldn't occur in a genetic algorithm? Any ideas where I could look for the error, or what's generally wrong with my implementation?
Main
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
public class Main {
final int numRows = 7;
final int numColums = 6;
final int randSeed = 123;
MultiLayerNetwork[] models;
static Random random = new Random();
private static final Logger log = LoggerFactory.getLogger(Main.class);
final float learningRate = .8f;
int batchSize = 64; // Test batch size
int nEpochs = 1; // Number of training epochs
// --
public static Main current;
Game mainGame = new Game();
public static void main(String[] args) {
current = new Main();
current.frame = new VGFrame();
current.loadWeights();
}
private VGFrame frame;
private final double mutationChance = .05;
public Main() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).seed(randSeed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Nesterovs(0.1, 0.9))
.list()
.layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(new DenseLayer.Builder().nIn(30).nOut(15).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build())
.layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).nIn(15).nOut(7)
.activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
.build();
models = new MultiLayerNetwork[35];
for (int i = 0; i < models.length; i++) {
models[i] = new MultiLayerNetwork(conf);
models[i].init();
}
}
public void addChip(int i, boolean b) {
if (mainGame.gameState == 0)
mainGame.addChip(i, b);
if (mainGame.gameState == 0) {
float[] f = Main.rowsToInput(mainGame.rows);
INDArray input = Nd4j.create(f);
INDArray output = models[0].output(input);
for (int i1 = 0; i1 < 7; i1++) {
System.out.println(i1 + ": " + output.getDouble(i1));
}
System.out.println("----------------");
mainGame.addChip(Main.getHighestOutput(output), false);
}
getFrame().paint(getFrame().getGraphics());
}
public void newGame() {
mainGame = new Game();
getFrame().paint(getFrame().getGraphics());
}
public void startTraining(int iterations) {
// --------------------------
for (int gameNumber = 0; gameNumber < iterations; gameNumber++) {
System.out.println("Iteration " + gameNumber + " of " + iterations);
float[] evaluation = new float[models.length];
for (int i = 0; i < models.length; i++) {
for (int j = 0; j < models.length; j++) {
if (i != j) {
Game g = new Game();
g.playFullGame(models[i], models[j]);
if (g.gameState == 1) {
evaluation[i] += 45;
evaluation[j] += g.turnNumber;
}
if (g.gameState == 2) {
evaluation[j] += 45;
evaluation[i] += g.turnNumber;
}
}
}
}
float[] evaluationSorted = evaluation.clone();
Arrays.sort(evaluationSorted);
// keep the best 4
int n1 = 0, n2 = 0, n3 = 0, n4 = 0, n5 = 0;
for (int i = 0; i < evaluation.length; i++) {
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 1])
n1 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 2])
n2 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 3])
n3 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 4])
n4 = i;
if (evaluation[i] == evaluationSorted[evaluationSorted.length - 5])
n5 = i;
}
models[0] = models[n1];
models[1] = models[n2];
models[2] = models[n3];
models[3] = models[n4];
models[4] = models[n5];
for (int i = 3; i < evaluationSorted.length; i++) {
// random parent/keep w8ts
double r = Math.random();
if (r > .3) {
models[i] = models[random.nextInt(3)].clone();
} else if (r > .1) {
models[i].setParams(breed(models[random.nextInt(3)], models[random.nextInt(3)]));
}
// Mutate
INDArray params = models[i].params();
models[i].setParams(mutate(params));
}
}
}
private INDArray mutate(INDArray params) {
double[] d = params.toDoubleVector();
for (int i = 0; i < d.length; i++) {
if (Math.random() < mutationChance)
d[i] += (Math.random() - .5) * learningRate;
}
return Nd4j.create(d);
}
private INDArray breed(MultiLayerNetwork m1, MultiLayerNetwork m2) {
double[] d = m1.params().toDoubleVector();
double[] d2 = m2.params().toDoubleVector();
for (int i = 0; i < d.length; i++) {
if (Math.random() < .5)
d[i] += d2[i];
}
return Nd4j.create(d);
}
static int getHighestOutput(INDArray output) {
int x = 0;
for (int i = 0; i < 7; i++) {
if (output.getDouble(i) > output.getDouble(x))
x = i;
}
return x;
}
static float[] rowsToInput(byte[][] rows) {
float[] f = new float[7 * 6];
for (int i = 0; i < 6; i++) {
for (int j = 0; j < 7; j++) {
// f[j + i * 7] = rows[j][i] / 2f;
f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
}
}
return f;
}
public void saveWeights() {
log.info("Saving model");
for (int i = 0; i < models.length; i++) {
File resourcesDirectory = new File("src/resources/model" + i);
try {
models[i].save(resourcesDirectory, true);
} catch (IOException e) {
e.printStackTrace();
}
}
}
public void loadWeights() {
if (new File("src/resources/model0").exists()) {
for (int i = 0; i < models.length; i++) {
File resourcesDirectory = new File("src/resources/model" + i);
try {
models[i] = MultiLayerNetwork.load(resourcesDirectory, true);
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
System.out.println("col: " + models[0].params().shapeInfoToString());
}
public VGFrame getFrame() {
return frame;
}
}
VGFrame
import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JTextField;
public class VGFrame extends JFrame {
JTextField iterations;
/**
*
*/
private static final long serialVersionUID = 1L;
public VGFrame() {
super("Vier Gewinnt");
this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
this.setSize(1300, 800);
this.setVisible(true);
JPanel panelGame = new JPanel();
panelGame.setBorder(BorderFactory.createLineBorder(Color.black, 2));
this.add(panelGame);
var handler = new Handler();
var menuHandler = new MenuHandler();
JButton b1 = new JButton("1");
JButton b2 = new JButton("2");
JButton b3 = new JButton("3");
JButton b4 = new JButton("4");
JButton b5 = new JButton("5");
JButton b6 = new JButton("6");
JButton b7 = new JButton("7");
b1.addActionListener(handler);
b2.addActionListener(handler);
b3.addActionListener(handler);
b4.addActionListener(handler);
b5.addActionListener(handler);
b6.addActionListener(handler);
b7.addActionListener(handler);
panelGame.add(b1);
panelGame.add(b2);
panelGame.add(b3);
panelGame.add(b4);
panelGame.add(b5);
panelGame.add(b6);
panelGame.add(b7);
JButton buttonTrain = new JButton("Train");
JButton buttonNewGame = new JButton("New Game");
JButton buttonSave = new JButton("Save Weights");
JButton buttonLoad = new JButton("Load Weights");
iterations = new JTextField("1000");
buttonTrain.addActionListener(menuHandler);
buttonNewGame.addActionListener(menuHandler);
buttonSave.addActionListener(menuHandler);
buttonLoad.addActionListener(menuHandler);
iterations.addActionListener(menuHandler);
panelGame.add(iterations);
panelGame.add(buttonTrain);
panelGame.add(buttonNewGame);
panelGame.add(buttonSave);
panelGame.add(buttonLoad);
this.validate();
}
@Override
public void paint(Graphics g) {
super.paint(g);
if (Main.current.mainGame.rows == null)
return;
var rows = Main.current.mainGame.rows;
for (int i = 0; i < rows.length; i++) {
for (int j = 0; j < rows[0].length; j++) {
if (rows[i][j] == 0)
break;
g.setColor((rows[i][j] == 1 ? Color.yellow : Color.red));
g.fillOval(80 + 110 * i, 650 - 110 * j, 100, 100);
}
}
}
public void update() {
}
}
class Handler implements ActionListener {
@Override
public void actionPerformed(ActionEvent event) {
if (Main.current.mainGame.playersTurn)
Main.current.addChip(Integer.parseInt(event.getActionCommand()) - 1, true);
}
}
class MenuHandler implements ActionListener {
@Override
public void actionPerformed(ActionEvent event) {
switch (event.getActionCommand()) {
case "New Game":
Main.current.newGame();
break;
case "Train":
Main.current.startTraining(Integer.parseInt(Main.current.getFrame().iterations.getText()));
break;
case "Save Weights":
Main.current.saveWeights();
break;
case "Load Weights":
Main.current.loadWeights();
break;
}
}
}
Game
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class Game {
int turnNumber = 0;
byte[][] rows = new byte[7][6];
boolean playersTurn = true;
int gameState = 0; // 0:running, 1:Player1, 2:Player2, 3:Draw
public boolean isRunning() {
return this.gameState == 0;
}
public void addChip(int x, boolean player1) {
turnNumber++;
byte b = nextRow(x);
if (b == 6) {
gameState = player1 ? 2 : 1;
return;
}
rows[x][b] = (byte) (player1 ? 1 : 2);
gameState = checkWinner(x, b);
}
private byte nextRow(int x) {
for (byte i = 0; i < rows[x].length; i++) {
if (rows[x][i] == 0)
return i;
}
return 6;
}
// 0 continue, 1 Player won, 2 ai won, 3 Draw
private int checkWinner(int x, int y) {
int color = rows[x][y];
// Vertikal
if (getCount(x, y, 1, 0) + getCount(x, y, -1, 0) >= 3)
return rows[x][y];
// Horizontal
if (getCount(x, y, 0, 1) + getCount(x, y, 0, -1) >= 3)
return rows[x][y];
// Diagonal1
if (getCount(x, y, 1, 1) + getCount(x, y, -1, -1) >= 3)
return rows[x][y];
// Diagonal2
if (getCount(x, y, -1, 1) + getCount(x, y, 1, -1) >= 3)
return rows[x][y];
for (byte[] bs : rows) {
for (byte s : bs) {
if (s == 0)
return 0;
}
}
return 3; // Draw
}
private int getCount(int x, int y, int dirX, int dirY) {
int color = rows[x][y];
int count = 0;
while (true) {
x += dirX;
y += dirY;
if (x < 0 | x > 6 | y < 0 | y > 5)
break;
if (color != rows[x][y])
break;
count++;
}
return count;
}
public void playFullGame(MultiLayerNetwork m1, MultiLayerNetwork m2) {
boolean player1 = true;
while (this.gameState == 0) {
float[] f = Main.rowsToInput(this.rows);
INDArray input = Nd4j.create(f);
this.addChip(Main.getHighestOutput(player1 ? m1.output(input) : m2.output(input)), player1);
player1 = !player1;
}
}
}
With a quick look, and based on the analysis of your multiplier variants, it seems like the NaN is produced by an arithmetic underflow, caused by your gradients being too small (too close to absolute 0).
This is the most suspicious part of the code:
f[j + i * 7] = (rows[j][i] == 0 ? .5f : rows[j][i] == 1 ? 0f : 1f);
If rows[j][i] == 1, then 0f is stored. A Java float can represent 0f exactly, so the stored value is valid in itself; I don't know, however, how the neural network handles it in later arithmetic (for example in multiplications or logarithms).
Even if your code nudged the 0f by a small offset, values derived from those array entries would still risk ending up too close to zero. Because floats have limited precision, values very close to zero cannot be represented and are rounded to zero (underflow); a later operation such as 0/0 or 0 * log(0) can then produce the NaN.
These values have a very friendly name: subnormal numbers.
Any non-zero number with magnitude smaller than the smallest normal number is subnormal.
As with IEEE 754-1985, The standard recommends 0 for signaling NaNs, 1 for quiet NaNs, so that a signaling NaNs can be quieted by changing only this bit to 1, while the reverse could yield the encoding of an infinity.
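Note that the text above describes a single bit inside the encoding of a NaN (the bit that distinguishes quiet from signaling NaNs), not the value 0f: according to the standard, storing 0f does not by itself specify a NaN.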
Despite its misleading name, Float.MIN_VALUE is a positive value, greater than 0. The real minimum float value is, in fact, -Float.MAX_VALUE.
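The relationship between these constants, subnormal numbers and NaN can be checked with a few lines of plain Java. This is only a minimal sketch: the class name FloatLimits and its helper methods are made up for illustration and are not part of the question's code or of deeplearning4j.

```java
public class FloatLimits {
    /** True when v is non-zero but smaller in magnitude than the smallest normal float. */
    static boolean isSubnormal(float v) {
        return v != 0f && Math.abs(v) < Float.MIN_NORMAL;
    }

    /** Half of the smallest positive float cannot be represented and rounds to zero. */
    static float underflow() {
        return Float.MIN_VALUE / 2f;
    }

    /** Dividing an underflowed zero by itself is one way a tiny value ends up as NaN. */
    static float zeroOverZero() {
        float zero = underflow();
        return zero / zero;
    }

    public static void main(String[] args) {
        System.out.println("MIN_VALUE is positive:   " + (Float.MIN_VALUE > 0f));
        System.out.println("most negative float:     " + (-Float.MAX_VALUE));
        System.out.println("MIN_VALUE is subnormal:  " + isSubnormal(Float.MIN_VALUE));
        System.out.println("MIN_NORMAL is subnormal: " + isSubnormal(Float.MIN_NORMAL));
        System.out.println("MIN_VALUE / 2:           " + underflow());
        System.out.println("underflowed 0 / 0:       " + zeroOverZero());
    }
}
```

The underflow itself only yields zero; the NaN appears in the follow-up operation, which is why it can be hard to spot where it originates.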
Is floating point math subnormal?
To check whether the 0f values alone cause the issue, you could replace them with other values that represent something similar: Float.MIN_VALUE, Float.MIN_NORMAL, and so on. Do something like this, also in any other parts of the code where this scenario could occur. Take these just as examples, and play with the ranges:
rows[j][i] == 1 ? Float.MIN_VALUE : 1f;
rows[j][i] == 1 ? Float.MIN_NORMAL : Float.MAX_VALUE/2;
rows[j][i] == 1 ? -Float.MAX_VALUE/2 : Float.MAX_VALUE/2;
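When experimenting with encodings like these, it can help to find out exactly when the first non-finite value appears. The following is a small sketch of such a check; the class NonFiniteCheck and the method firstNonFinite are hypothetical helpers, not part of the question's code or of deeplearning4j.

```java
public class NonFiniteCheck {
    /** Returns the index of the first NaN or infinite entry, or -1 if every entry is finite. */
    static int firstNonFinite(double[] values) {
        for (int i = 0; i < values.length; i++) {
            if (!Double.isFinite(values[i])) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        double[] healthy = {0.0, 0.5, 1.0};
        double[] broken = {0.5, Double.NaN, Double.POSITIVE_INFINITY};
        System.out.println("healthy: " + firstNonFinite(healthy));
        System.out.println("broken:  " + firstNonFinite(broken));
    }
}
```

In the question's code, one possible place to call it would be on models[i].params().toDoubleVector() after each mutate or breed step, so that the generation in which the parameters first stop being finite becomes visible.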
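Even so, this could also lead to a NaN, depending on how these values are altered.
If so, you should normalize the values. You could try applying a GradientNormalizer for this. Note that gradient normalization acts on the gradients computed during backpropagation, so it only has an effect when the networks are trained with fit(). In your network initialization, something like this should be defined for each layer (or for those that are problematic):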
import org.deeplearning4j.nn.conf.GradientNormalization;
(...)
new NeuralNetConfiguration.Builder()
    .weightInit(WeightInit.XAVIER)
    (...)
    .layer(new DenseLayer.Builder().nIn(42).nOut(30).activation(Activation.RELU)
        .weightInit(WeightInit.XAVIER)
        // rescale this layer's gradients by their L2 norm
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
        .build())
    (...)
There are different normalizers, so choose which one fits your schema best, and which layers should include one. The options are:
RenormalizeL2PerLayer
Rescale gradients by dividing by the L2 norm of all gradients for the layer.
RenormalizeL2PerParamType
Rescale gradients by dividing by the L2 norm of the gradients, separately for each type of parameter within the layer. This differs from RenormalizeL2PerLayer in that here, each parameter type (weight, bias etc) is normalized separately. For example, in an MLP/FeedForward network (where G is the gradient vector), the output is as follows:
GOut_weight = G_weight / l2(G_weight)
GOut_bias = G_bias / l2(G_bias)
ClipElementWiseAbsoluteValue
Clip the gradients on a per-element basis. For each gradient g, set g <- sign(g) * min(maxAllowedValue, |g|), i.e., if a parameter gradient has absolute value greater than the threshold, truncate it. For example, if threshold = 5, then values in the range -5 <= g <= 5 are unmodified; values < -5 are set to -5; values > 5 are set to 5.
ClipL2PerLayer
Conditional renormalization. Somewhat similar to RenormalizeL2PerLayer, this strategy scales the gradients if and only if the L2 norm of the gradients (for entire layer) exceeds a specified threshold. Specifically, if G is gradient vector for the layer, then:
GOut = G if l2Norm(G) < threshold (i.e., no change)
GOut = threshold * G / l2Norm(G) otherwise
ClipL2PerParamType
Conditional renormalization. Very similar to ClipL2PerLayer, however instead of clipping per layer, do clipping on each parameter type separately. For example in a recurrent neural network, input weight gradients, recurrent weight gradients and bias gradient are all clipped separately.
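To make the formulas above concrete, here is a minimal plain-Java sketch of three of the strategies applied to a single gradient vector. It is a simplified model of the documented formulas, not deeplearning4j's implementation; the class GradientNormSketch and all of its method names are assumptions made for illustration.

```java
public class GradientNormSketch {
    /** L2 norm (Euclidean length) of the gradient vector. */
    static double l2Norm(double[] g) {
        double sum = 0.0;
        for (double v : g) {
            sum += v * v;
        }
        return Math.sqrt(sum);
    }

    /** RenormalizeL2 idea: divide every gradient by the L2 norm (left unchanged if the norm is 0). */
    static double[] renormalizeL2(double[] g) {
        double norm = l2Norm(g);
        double[] out = g.clone();
        if (norm == 0.0) {
            return out;
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= norm;
        }
        return out;
    }

    /** ClipElementWiseAbsoluteValue idea: clamp each gradient to [-threshold, threshold]. */
    static double[] clipElementWise(double[] g, double threshold) {
        double[] out = new double[g.length];
        for (int i = 0; i < g.length; i++) {
            out[i] = Math.signum(g[i]) * Math.min(threshold, Math.abs(g[i]));
        }
        return out;
    }

    /** ClipL2 idea: rescale to length threshold only if the L2 norm exceeds the threshold. */
    static double[] clipL2(double[] g, double threshold) {
        double norm = l2Norm(g);
        double[] out = g.clone();
        if (norm <= threshold) {
            return out; // already short enough, no change
        }
        double scale = threshold / norm;
        for (int i = 0; i < out.length; i++) {
            out[i] *= scale;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] g = {3.0, 4.0}; // L2 norm is 5
        System.out.println(java.util.Arrays.toString(renormalizeL2(g)));
        System.out.println(java.util.Arrays.toString(clipElementWise(new double[] {-7.0, 2.0, 9.0}, 5.0)));
        System.out.println(java.util.Arrays.toString(clipL2(g, 2.5)));
    }
}
```

The per-layer and per-parameter-type variants differ only in which gradients are grouped into the vector g before the same arithmetic is applied.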
Here you can find a complete example of the application of these GradientNormalizers.