Search code examples
Tags: c#, algorithm, math, pseudocode, translate

Cross-Entropy Method translation from mathematical notation


I want to use the cross-entropy method for parameter selection in an algorithm I'm working with. The problem is that I don't understand mathematical notation very well, and I can't find this version of the cross-entropy method implemented in code anywhere.

The algorithm, in pseudo code, can be seen in this image:

https://i.sstatic.net/jOQtc.png (I can't paste it here because it contains a lot of LaTeX.)

It was taken from this paper: https://project.dke.maastrichtuniversity.nl/games/files/phd/Chaslot_thesis.pdf (page 69)

Could you help me translate it into C#, any other programming language, or even plain English?

Thanks!


Solution

  • Robert Dodier's clarifications helped me in some ways but confused me even more in others. So I went back to a Ruby implementation of the cross-entropy method that I had seen earlier but had assumed was not the exact algorithm I was trying to "translate". With the newfound knowledge from those clarifications, I saw that it was indeed the same algorithm, and I translated it into C#.

    Original Ruby code: http://www.cleveralgorithms.com/nature-inspired/probabilistic/cross_entropy.html

    My translation into C#:

    /// <summary>
    /// Cross-Entropy Method (CEM) for continuous optimisation, translated from the
    /// Clever Algorithms Ruby implementation. Every dimension is modelled by an
    /// independent Gaussian; each iteration draws samples (clamped to the bounds),
    /// keeps the highest-scoring "elite" samples and moves every Gaussian's mean and
    /// standard deviation towards the elite statistics. Scores are MAXIMISED.
    /// </summary>
    class CrossEntropyMethod
    {
        private readonly Random r;

        /// <summary>Creates an optimiser with a time-seeded random generator.</summary>
        public CrossEntropyMethod()
        {
            r = new Random();
        }

        /// <summary>Creates an optimiser with a fixed seed so runs are reproducible.</summary>
        /// <param name="seed">Seed for the internal <see cref="Random"/>.</param>
        public CrossEntropyMethod(int seed)
        {
            r = new Random(seed);
        }

        // Negated sphere function: -sum(x_i^2); maximum 0 at the origin.
        double objective_function(double[] vector)
        {
            double sum = 0.0;
            foreach (var f in vector)
            {
                sum += f * f;
            }
            return -sum;
        }

        // Scores how close vector[0] is to a root of 5x^2 + 10x - 2
        // (roots x = -2.183216 and x = 0.183216); maximum 0 at either root.
        double QuadraticEquation(double[] vector)
        {
            double value = 5 * Math.Pow(vector[0], 2) + 10 * vector[0] - 2;
            return -Math.Abs(value);
        }

        // Treats vector = (a, b, c) as coefficients of a*x^2 + b*x + c and scores how
        // close that polynomial is to vanishing at both roots of 5x^2 + 10x - 2
        // (0.183216 and -2.183216). Maximum ~0 for any multiple of (5, 10, -2).
        double QuadraticEquation2(double[] vector)
        {
            double sum1 = vector[0] * Math.Pow(0.183216, 2) + vector[1] * 0.183216 + vector[2];
            double sum2 = vector[0] * Math.Pow(-2.183216, 2) + vector[1] * -2.183216 + vector[2];
            return -(Math.Abs(sum1) + Math.Abs(sum2));
        }

        // Uniform draw in [min, max).
        double random_variable(double min, double max)
        {
            return min + ((max - min) * r.NextDouble());
        }

        // Normal draw via the Marsaglia polar method.
        double random_gaussian(double mean = 0.0, double stdev = 1.0)
        {
            double u1, u2, w;
            do
            {
                u1 = 2 * r.NextDouble() - 1;
                u2 = 2 * r.NextDouble() - 1;
                w = u1 * u1 + u2 * u2;
            }
            // Reject points outside the unit circle AND the origin: w == 0 would make
            // Log(w) / w undefined (NaN).
            while (w >= 1 || w == 0);

            w = Math.Sqrt((-2.0 * Math.Log(w)) / w);
            return mean + (u2 * w) * stdev;
        }

        // Draws one candidate: a Gaussian sample per dimension, clamped to
        // [search_space[i][0], search_space[i][1]].
        double[] generate_sample(double[][] search_space, double[] means, double[] stdevs)
        {
            double[] vector = new double[search_space.Length];
            for (int i = 0; i < vector.Length; i++)
            {
                double value = random_gaussian(means[i], stdevs[i]);
                value = Math.Max(value, search_space[i][0]);
                vector[i] = Math.Min(value, search_space[i][1]);
            }
            return vector;
        }

        // Smoothed update of the sampling distribution from the elite samples:
        //   mean  <- alpha * mean  + (1 - alpha) * elite mean
        //   stdev <- alpha * stdev + (1 - alpha) * elite population stdev
        // alpha = 1 would leave the distribution unchanged. The arrays are
        // updated in place. samples must be non-empty.
        void update_distribution(double[][] samples, double alpha, double[] means, double[] stdevs)
        {
            int count = samples.Length;
            for (int i = 0; i < means.Length; i++)
            {
                double eliteMean = 0.0;
                foreach (var sample in samples)
                {
                    eliteMean += sample[i];
                }
                eliteMean /= count;

                double squaredDeviations = 0.0;
                foreach (var sample in samples)
                {
                    double d = sample[i] - eliteMean;
                    squaredDeviations += d * d;
                }
                // Population standard deviation (divide by N), as in the Ruby reference.
                double eliteStdev = Math.Sqrt(squaredDeviations / count);

                means[i] = alpha * means[i] + ((1.0 - alpha) * eliteMean);
                stdevs[i] = alpha * stdevs[i] + ((1.0 - alpha) * eliteStdev);
            }
        }

        /// <summary>
        /// Runs the cross-entropy search and returns the best sample seen.
        /// </summary>
        /// <param name="bounds">Per-dimension {min, max} pairs.</param>
        /// <param name="max_iter">Number of iterations; when &lt;= 0 no search happens and null is returned.</param>
        /// <param name="num_samples">Samples drawn per iteration (&gt;= 1).</param>
        /// <param name="num_update">Number of elite samples used for the update (1..num_samples).</param>
        /// <param name="learning_rate">Smoothing weight of the OLD distribution (alpha); must be &lt; 1 for the search to learn.</param>
        /// <param name="objective">Score to maximise; defaults to <c>QuadraticEquation</c>.</param>
        /// <returns>The highest-scoring sample found, or null when max_iter &lt;= 0.</returns>
        /// <exception cref="ArgumentNullException">bounds is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">num_samples or num_update is out of range.</exception>
        public double[] search(double[][] bounds, int max_iter, int num_samples, int num_update, double learning_rate,
                               Func<double[], double> objective = null)
        {
            if (bounds == null)
            {
                throw new ArgumentNullException("bounds");
            }
            if (num_samples < 1)
            {
                throw new ArgumentOutOfRangeException("num_samples", "At least one sample per iteration is required.");
            }
            if (num_update < 1 || num_update > num_samples)
            {
                throw new ArgumentOutOfRangeException("num_update", "num_update must be between 1 and num_samples.");
            }

            Func<double[], double> score = objective;
            if (score == null)
            {
                score = QuadraticEquation;
            }

            // Start each Gaussian at a random point in its range, as wide as the range.
            double[] means = new double[bounds.Length];
            double[] stdevs = new double[bounds.Length];
            for (int i = 0; i < means.Length; i++)
            {
                means[i] = random_variable(bounds[i][0], bounds[i][1]);
                stdevs[i] = bounds[i][1] - bounds[i][0];
            }

            double[] best = null;
            double bestScore = double.MinValue;
            for (int t = 0; t < max_iter; t++)
            {
                double[][] samples = new double[num_samples][];
                double[] scores = new double[num_samples];
                for (int s = 0; s < num_samples; s++)
                {
                    samples[s] = generate_sample(bounds, means, stdevs);
                    scores[s] = score(samples[s]);
                }

                // Sort samples by score, highest first.
                Array.Sort(scores, samples);
                Array.Reverse(scores);
                Array.Reverse(samples);

                if (best == null || scores[0] > bestScore)
                {
                    bestScore = scores[0];
                    best = samples[0];
                }

                double[][] selected = new double[num_update][];
                Array.Copy(samples, selected, num_update);
                update_distribution(selected, learning_rate, means, stdevs);
                Console.WriteLine("iteration={0}, fitness={1}", t, bestScore);
            }
            return best;
        }

        /// <summary>Demo: finds a root of 5x^2 + 10x - 2 in [-500, 500] and prints it.</summary>
        public void Run()
        {
            double[][] parameters = new double[][] { new double[] { -500, 500 } }; //QuadraticEquation parameters
            //double[][] parameters = new double[][] { new double[] { 4, 6 }, new double[] { 9, 11 }, new double[] { -3, -1} }; //QuadraticEquation2 parameters (pass QuadraticEquation2 as objective)
            //double[][] parameters = new double[][] { new double[] { -5, 5 }, new double[] { -5, 5 }, new double[] { -5, 5 } }; //objective_function parameters (pass objective_function as objective)
            int maxIter = 100;
            int nSamples = 50;
            int nUpdate = 5;
            // Weight kept from the previous distribution. It must be < 1, because
            // alpha = 1 freezes the distribution. 0.7 matches the Ruby reference.
            double alpha = 0.7;
            double[] best = search(parameters, maxIter, nSamples, nUpdate, alpha, QuadraticEquation);
            string str = string.Join(" | ", best.Select(a => a.ToString("N10")).ToArray());
            Console.WriteLine("Best: " + str);
        }
    }