Search code examples
Tags: java, algorithm, neural-network, genetic-algorithm, encog

How can I pause/serialize a genetic algorithm in Encog?


How can I pause a genetic algorithm in Encog 3.4 (the version currently under development in Github)?

I am using the Java version of Encog.

I am trying to modify the Lunar example that comes with Encog. I want to pause/serialize the genetic algorithm and then continue/deserialize at a later stage.

When I call `train.pause()`, it simply returns null, which is clear from the code since the method always returns null.

I would assume this is straightforward, since there are scenarios in which I want to train a neural network, use it for some predictions, and then, as I get more data, continue training with the genetic algorithm before making more predictions — all without having to restart training from the beginning.

Please note that I am not trying to serialize or persist a neural network but rather the entire genetic algorithm.


Solution

  • Not all trainers in Encog support simple pause/resume; those that do not, like this one, return null from `pause()`. The genetic algorithm trainer is much more complex than a propagation trainer, which does support pause/resume. To save the state of the genetic algorithm, you must save the entire population, as well as the scoring function (which may or may not be serializable). I modified the Lunar Lander example to show how you might save and reload your population of neural networks to do this.

    You can see that it trains for 50 iterations, round-trips the genetic algorithm (saves it to disk and then reloads it), and then trains for 50 more.

    package org.encog.examples.neural.lunar;
    
    import java.io.File;
    import java.io.IOException;
    
    import org.encog.Encog;
    import org.encog.engine.network.activation.ActivationTANH;
    import org.encog.ml.MLMethod;
    import org.encog.ml.MLResettable;
    import org.encog.ml.MethodFactory;
    import org.encog.ml.ea.population.Population;
    import org.encog.ml.genetic.MLMethodGeneticAlgorithm;
    import org.encog.ml.genetic.MLMethodGenomeFactory;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.pattern.FeedForwardPattern;
    import org.encog.util.obj.SerializeObject;
    
    /**
     * Lunar Lander example showing how to checkpoint an Encog genetic algorithm:
     * train for a while, serialize the GA's population to disk, reload it into a
     * new trainer, and continue training from where it left off.
     */
    public class LunarLander {

        /** Population size used when the GA is first created. */
        private static final int POPULATION_SIZE = 500;

        /** Number of GA iterations run before and after the save/load round trip. */
        private static final int ITERATIONS_PER_PHASE = 50;

        /**
         * Builds a freshly randomized feed-forward network with 3 inputs, one hidden
         * layer of 50 neurons, 1 output, and TANH activation.
         *
         * @return a new network whose weights have already been reset
         */
        public static BasicNetwork createNetwork()
        {
            FeedForwardPattern pattern = new FeedForwardPattern();
            pattern.setInputNeurons(3);
            pattern.addHiddenLayer(50);
            pattern.setOutputNeurons(1);
            pattern.setActivationFunction(new ActivationTANH());
            BasicNetwork network = (BasicNetwork) pattern.generate();
            network.reset();
            return network;
        }

        /**
         * Returns the factory the GA uses to create new candidate networks.
         * Shared by the initial trainer, the reloaded trainer and the genome factory,
         * so all of them produce identically shaped networks.
         */
        private static MethodFactory networkFactory()
        {
            return new MethodFactory() {
                @Override
                public MLMethod factor() {
                    // createNetwork() already calls reset(), so no second reset is needed.
                    return createNetwork();
                }
            };
        }

        /**
         * Serializes the GA's population to {@code file}.
         *
         * <p>The population's genome factory is cleared while the population is written.
         * Presumably this is because it wraps a non-serializable MethodFactory; it is
         * rebuilt on load. The factory is restored afterwards, even if the write fails,
         * so {@code ga} can keep training after a checkpoint. The restored factory uses
         * this class's network factory, which assumes {@code ga} was built with an
         * equivalent one (as in {@link #main}).
         *
         * @param file path of the checkpoint file to write
         * @param ga   the genetic algorithm whose population is saved
         * @throws IOException if the population cannot be written
         */
        public static void saveMLMethodGeneticAlgorithm(String file, MLMethodGeneticAlgorithm ga ) throws IOException
        {
            Population pop = ga.getGenetic().getPopulation();
            pop.setGenomeFactory(null);
            try {
                SerializeObject.save(new File(file), pop);
            } finally {
                // Without this the live GA would be left with a null genome factory.
                pop.setGenomeFactory(new MLMethodGenomeFactory(networkFactory(), pop));
            }
        }

        /**
         * Rebuilds a genetic algorithm trainer from a population written by
         * {@link #saveMLMethodGeneticAlgorithm}.
         *
         * <p>NOTE: this uses Java native deserialization. Only load checkpoint files
         * this program wrote itself, never untrusted input.
         *
         * @param filename path of the checkpoint file to read
         * @return a trainer that continues from the saved population
         * @throws ClassNotFoundException if the file references classes not on the classpath
         * @throws IOException if the file cannot be read
         */
        public static MLMethodGeneticAlgorithm loadMLMethodGeneticAlgorithm(String filename) throws ClassNotFoundException, IOException {
            Population pop = (Population) SerializeObject.load(new File(filename));
            // The genome factory was not saved; recreate it around the loaded population.
            pop.setGenomeFactory(new MLMethodGenomeFactory(networkFactory(), pop));

            // The constructor requires a population size. Use 1 to keep the initial
            // population minimal, since it is replaced by the loaded one right below.
            MLMethodGeneticAlgorithm result =
                    new MLMethodGeneticAlgorithm(networkFactory(), new PilotScore(), 1);
            result.getGenetic().setPopulation(pop);

            return result;
        }

        /**
         * Runs {@code iterations} GA iterations, printing the score after each one.
         *
         * @param train      the trainer to iterate
         * @param iterations number of iterations to run
         * @param firstEpoch epoch number printed for the first iteration
         * @return the epoch number to use for the next iteration
         */
        private static int trainEpochs(MLMethodGeneticAlgorithm train, int iterations, int firstEpoch)
        {
            int epoch = firstEpoch;
            for (int i = 0; i < iterations; i++) {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Score:" + train.getError());
                epoch++;
            }
            return epoch;
        }

        /**
         * Trains for {@value #ITERATIONS_PER_PHASE} iterations, round-trips the GA
         * through a checkpoint file, trains for {@value #ITERATIONS_PER_PHASE} more,
         * and then flies the best network.
         *
         * @param args optional {@code args[0]}: checkpoint file path. If absent, a
         *             temporary file is used and deleted on exit.
         */
        public static void main(String[] args)
        {
            MLMethodGeneticAlgorithm train =
                    new MLMethodGeneticAlgorithm(networkFactory(), new PilotScore(), POPULATION_SIZE);

            try {
                String checkpoint;
                if (args.length > 0) {
                    checkpoint = args[0];
                } else {
                    File temp = File.createTempFile("trainer", ".bin");
                    temp.deleteOnExit();
                    checkpoint = temp.getAbsolutePath();
                }

                int epoch = trainEpochs(train, ITERATIONS_PER_PHASE, 1);
                train.finishTraining();

                // Round trip the GA, then keep training the reloaded copy.
                saveMLMethodGeneticAlgorithm(checkpoint, train);
                train = loadMLMethodGeneticAlgorithm(checkpoint);

                trainEpochs(train, ITERATIONS_PER_PHASE, epoch);
                train.finishTraining();

                System.out.println("\nHow the winning network landed:");
                BasicNetwork network = (BasicNetwork) train.getMethod();
                NeuralPilot pilot = new NeuralPilot(network, true);
                System.out.println(pilot.scorePilot());
            } catch (IOException ex) {
                ex.printStackTrace();
            } catch (ClassNotFoundException ex) {
                ex.printStackTrace();
            } finally {
                Encog.getInstance().shutdown();
            }
        }
    }