
How to convert a recursive function to mex code?


I have a recursive function choose in MATLAB code as follows:

    function nk = choose(n, k)
        if (k == 0)
            nk = 1;
        else
            nk = (n * choose(n - 1, k - 1)) / k;
        end
    end

The code computes the binomial coefficient of n and k (the number of combinations of k items out of n). I want to speed it up by writing a MEX file. I tried to write the MEX code as follows:

double choose(double* n, double* k)
{
    if (k==0)
        return 1;
    else
        return (n * choose(n - 1, k - 1)) / k;
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double *n, *k, *nk;
    int mrows, ncols;
    plhs[0] = mxCreateDoubleMatrix(1,1, mxREAL);
    /* Assign pointers to each input and output. */
    n = mxGetPr(prhs[0]);    
    k = mxGetPr(prhs[1]);
    nk = mxGetPr(plhs[0]);
    /* Call the recursive function. */
    nk=choose(n,k);
}

However, it does not work. Could you help me modify the MEX code so that it implements the MATLAB code above? Thanks.


Solution

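  • The following code fixes your C MEX implementation.
    The problem is not the recursion, of course: your code uses pointers where it needs values. In your choose, n and k are double*, so k==0 tests the address, n - 1 is pointer arithmetic, and n * ... does not compile; likewise nk=choose(n,k) tries to assign a double to a pointer instead of writing into the output array.
    In C it's important to use pointers only in the right places: pass the values (*n, *k) and store the result through the output pointer (*nk = ...).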

    You can also use MATLAB's built-in function nchoosek.
    See: http://www.mathworks.com/help/matlab/ref/nchoosek.html

    The following code works:

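    //choose.c

    #include "mex.h"

    /* Recursive n-choose-k: C(n,k) = n * C(n-1,k-1) / k, with C(n,0) = 1. */
    double choose(double n, double k)
    {
        if (k == 0)
        {
            return 1;
        }
        else
        {
            return (n * choose(n - 1, k - 1)) / k;
        }
    }


    void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
    {
        double *n, *k, *nk;

        /* Reading prhs[1] without two inputs would crash MATLAB. */
        if (nrhs != 2)
        {
            mexErrMsgIdAndTxt("choose:nrhs", "Two inputs (n, k) are required.");
        }

        plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
        /* Assign pointers to each input and output. */
        n = mxGetPr(prhs[0]);
        k = mxGetPr(prhs[1]);
        nk = mxGetPr(plhs[0]);

        /* Call the recursive function on the values (not the pointers)
           and write the result through the output pointer. */
        *nk = choose(*n, *k);
    }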
    

    Compile it from within MATLAB: mex choose.c

    Execute:
    choose(10,5)
    ans =

    252
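To check the fixed recursion outside MATLAB, here is a minimal standalone sketch of the same choose function without mex.h. The helper call_like_mex is hypothetical: it only mimics what mexFunction does, with local variables standing in for the double* pointers that mxGetPr returns, so the value-versus-pointer fix can be seen in isolation.

```c
#include <assert.h>

/* Same recursion as the fixed MEX file, minus mex.h:
 * C(n,k) = n * C(n-1,k-1) / k, with C(n,0) = 1. */
double choose(double n, double k)
{
    assert(k >= 0);   /* assumed precondition: integer-valued 0 <= k <= n */
    if (k == 0)
        return 1.0;
    return (n * choose(n - 1, k - 1)) / k;
}

/* Hypothetical stand-in for mexFunction: in_n and in_k play the role of the
 * inputs, and out plays the role of the output array's data. */
double call_like_mex(double in_n, double in_k)
{
    double out = 0.0;
    const double *n = &in_n, *k = &in_k;   /* like mxGetPr(prhs[0]), mxGetPr(prhs[1]) */
    double *nk = &out;                     /* like mxGetPr(plhs[0]) */

    *nk = choose(*n, *k);                  /* pass values, write through the pointer */
    return out;
}
```

For example, call_like_mex(10, 5) returns 252, matching the MATLAB output above.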

    This is not an efficient implementation.
    I am only fixing your implementation so that it works, to be used as an "inefficient example".


    Measuring the execution time of rahnema1's implementation:
    tic;n = 1000000;k = 500000;nk = prod((k+1:n) .* prod((1:n-k).^ (-1/(n-k))));toc
    Elapsed time is 0.022855 seconds.

    Measuring the execution time of the choose.mexw64 implementation:
    tic;n = 1000000;k = 500000;nk = choose(1000000, 500000);toc
    Elapsed time is 0.007952 seconds.
    (This took about a third of the time of prod((k+1:n) .* prod((1:n-k).^ (-1/(n-k))))).

    Measuring the MATLAB recursion fails with an error (even for n = 700 and k = 500):
    tic;n = 700;k = 500;nk = RecursiveFunctionTest(n, k);toc
    Maximum recursion limit of 500 reached. Use set(0,'RecursionLimit',N) to change the limit. Be aware that exceeding your available stack space can crash MATLAB and/or your computer.

    tic;n = 700;k = 400;nk = RecursiveFunctionTest(n, k);toc
    Elapsed time is 0.005635 seconds.
    Very inefficient...

    Measuring MATLAB's built-in function nchoosek:
    tic;nchoosek(1000000, 500000);toc
    Warning: Result may not be exact. Coefficient is greater than 9.007199e+15 and is only accurate to 15 digits
    In nchoosek at 92
    Elapsed time is 0.005081 seconds.
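The threshold in the nchoosek warning, 9.007199e+15, is 2^53: above it a double can no longer represent every integer, so large coefficients are rounded. A minimal sketch that shows this; TWO_POW_53 and has_exact_successor are hypothetical names used only for illustration.

```c
#include <assert.h>

/* 2^53 = 9007199254740992 (about 9.007199e+15), the threshold quoted in the
 * nchoosek warning. Above it, consecutive integers are no longer all
 * representable as doubles. */
#define TWO_POW_53 9007199254740992.0

/* Returns 1 if x + 1 is a different double from x, i.e. the step to the
 * next integer is still exact; assumes x is a non-negative integer value. */
int has_exact_successor(double x)
{
    double y = x + 1.0;   /* result rounded to double */
    assert(x >= 0);
    return y != x;
}
```

has_exact_successor(9007199254740991.0) is 1, while has_exact_successor(9007199254740992.0) is 0, because 2^53 + 1 rounds back to 2^53 under IEEE 754 round-to-nearest-even.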

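    Conclusion:
    Implement the C MEX file without recursion, and measure again.
    Keep in mind that C(1000000, 500000) is far larger than the largest double (about 1.8e308), so none of the n = 1000000, k = 500000 calls above can return the exact value; those timings compare speed only.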
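Removing the recursion is easier once it is unrolled. The recursive choose relies on the identity below (with base case C(n, 0) = 1); expanding it gives a product that a plain loop can compute, and the last form is the one used by the factorial version below:

```latex
\binom{n}{k} = \frac{n}{k}\binom{n-1}{k-1}
             = \frac{n (n-1) \cdots (n-k+1)}{k!}
             = \prod_{i=1}^{k} \frac{n-k+i}{i}
             = \frac{n!}{(n-k)!\,k!}
```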


    Measure without recursion:

    /* Iterative factorial. Note: for number > 170 the result
       overflows to Inf in double precision. */
    static double factorial(double number)
    {
        int x;
        double fac = 1;

        if (number == 0)
        {
            return 1.0;
        }

        for (x = 2; x <= (int)number; x++)
        {
            fac = fac * x;
        }

        return fac;
    }



    double choose(double n, double k)
    {
        if (k == 0)
        {
            return 1.0;
        }
        else
        {
            /* n!/((n-k)! k!) */
            return factorial(n)/(factorial(n-k)*factorial(k));
        }
    }
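To see where the factorial approach stops working, here is a minimal sketch that reuses the same loop (without mex.h) and searches for the largest n whose factorial is still finite in double precision. largest_finite_factorial_arg is a hypothetical helper, not part of the answer.

```c
#include <assert.h>
#include <math.h>

/* Same loop as factorial() above; the number == 0 branch is not needed,
 * since the loop body does not run for 0 or 1. */
double factorial(double number)
{
    int x;
    double fac = 1.0;

    for (x = 2; x <= (int)number; x++)
        fac = fac * x;   /* becomes +Inf once the product exceeds DBL_MAX */
    return fac;
}

/* Hypothetical helper: the largest n for which factorial(n) is finite. */
int largest_finite_factorial_arg(void)
{
    int n = 0;

    assert(isfinite(factorial(0)));
    while (isfinite(factorial(n + 1)))
        n++;
    return n;
}
```

On IEEE 754 doubles it returns 170: 170! is about 7.3e306, while 171! exceeds the largest double (about 1.8e308).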
    

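    tic;choose(1000000, 500000);toc

    Elapsed time is 0.003079 seconds.

    Faster than the recursive version. Note, however, that factorial(1000000) overflows to Inf, so this call returns NaN (Inf/Inf) rather than the coefficient; for k > 0 the factorial formula gives a finite result only when n <= 170.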