MATLAB's FFT does not let you choose how many threads perform the computation (http://stackoverflow.com/questions/9528833/matlabs-fftn-gets-slower-with-multithreading). By default it uses all the cores in a standalone MATLAB session, but on a cluster each worker is launched with a single CPU by default. You can force a worker to use more cores (with the maxNumCompThreads function). This works perfectly for algebraic operations, but the FFT function remains (weirdly?) single-core. I therefore wrote a MEX file using the FFTW library (as MATLAB does) to compute the FFT with the desired number of cores. But when I compare the two codes using the FFTW_ESTIMATE planner (which is the default in MATLAB) and cleared wisdom, my code remains 3 to 4 times slower than the MATLAB fft.
Here is the code I used for the mex (applied for 2D fft, named FFT2mx):
#include <stdlib.h>
#include <stdio.h>
#include <mex.h>
#include <matrix.h>
#include <math.h>
#include </home/nicolas/Code/C/lib/include/fftw3.h>
/*
 * FFTNDSplit - out-of-place N-dimensional complex DFT on split (separate
 * real / imaginary) arrays, using FFTW's guru split interface.
 *
 *   NumDims       number of dimensions, must be >= 1
 *   N             extent of each dimension; N[0] is the unit-stride dimension
 *   XReal, XImag  input real and imaginary planes
 *   YReal, YImag  output real and imaginary planes
 *   Sign          -1 selects the forward transform; any other value selects
 *                 the inverse transform (FFTW does not normalise it)
 *
 * Failures are reported through mexErrMsgTxt(), which does not return.
 * A plan is built with FFTW_ESTIMATE and destroyed again on every call.
 */
void FFTNDSplit(int NumDims, const int N[], double *XReal, double *XImag, double *YReal, double *YImag, int Sign)
{
    fftw_plan Plan;
    int k, Stride;

    /* A variable-length array with a non-positive length is undefined. */
    if (NumDims < 1)
        mexErrMsgTxt("FFTNDSplit requires at least one dimension.");
    if (!XReal || !XImag || !YReal || !YImag)
        mexErrMsgTxt("FFTNDSplit requires non-NULL real and imaginary arrays.");

    /*
     * The inverse transform is the forward transform with the real and
     * imaginary planes exchanged.  The exchange has to happen BEFORE planning:
     * a split plan may only be executed on arrays whose imaginary-minus-real
     * separation matches the arrays it was planned for, and swapping the
     * pointers afterwards changes that separation.
     */
    if (Sign != -1) {
        double *Tmp;
        Tmp = XReal; XReal = XImag; XImag = Tmp;
        Tmp = YReal; YReal = YImag; YImag = Tmp;
    }

    /* Declared only now that NumDims is known to be positive. */
    fftw_iodim Dim[NumDims];

    /* Dimensions are stored in reverse order; strides grow with k. */
    for (k = 0, Stride = 1; k < NumDims; k++) {
        fftw_iodim *D = &Dim[NumDims - k - 1];
        D->n = N[k];
        D->is = D->os = Stride;
        Stride *= N[k];
    }

    /* Wisdom import/export was disabled in the original and is left out here. */
    Plan = fftw_plan_guru_split_dft(NumDims, Dim, 0, NULL, XReal,
                                    XImag, YReal, YImag, FFTW_ESTIMATE);
    if (!Plan)
        mexErrMsgTxt("FFTW3 failed to create plan.");

    /* Executed on exactly the arrays the plan was created for. */
    fftw_execute_split_dft(Plan, XReal, XImag, YReal, YImag);

    fftw_destroy_plan(Plan);
    return;
}
/*
 * MEX gateway:  Y = FFT2mx(X, NumThreads)
 *
 *   prhs[0]  full double array X, real or complex
 *   prhs[1]  numeric scalar, number of FFTW threads (1..8)
 *   plhs[0]  complex double array of the same size as X holding the forward
 *            transform computed by FFTNDSplit (no scaling is applied here)
 *
 * All failures are raised with mexErrMsgIdAndTxt(), which does not return.
 * NOTE(review): uses mxGetPr/mxGetPi, i.e. the separate-complex MEX API.
 */
void mexFunction( int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[] )
{
    int k, numCPU, NumDims;
    const mwSize *N;
    size_t NumEl;
    double *XReal, *XImag, *ZeroImag = NULL;

    (void) nlhs;

    if (nrhs != 2) {
        mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidNumInputs",
        "Two input argument required.");
    }
    if (!mxIsDouble(prhs[0])) {
        mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidNumInputs",
        "Array must be double");
    }
    /* Sparse storage is not a dense plane of NumEl doubles. */
    if (mxIsSparse(prhs[0])) {
        mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidInput",
        "Array must be full (non-sparse)");
    }
    if (!mxIsNumeric(prhs[1]) || mxGetNumberOfElements(prhs[1]) != 1) {
        mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidInput",
        "Number of threads must be a numeric scalar");
    }
    numCPU = (int) mxGetScalar(prhs[1]);
    if (numCPU < 1 || numCPU > 8) {
        mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidNumInputs",
        "Number of threads must be between 1 and 8");
    }

    NumDims = (int) mxGetNumberOfDimensions(prhs[0]);
    N = mxGetDimensions(prhs[0]);
    NumEl = mxGetNumberOfElements(prhs[0]);

    /* Created empty and resized so the data planes are not zero-filled first. */
    plhs[0] = mxCreateDoubleMatrix(0, 0, mxCOMPLEX);
    mxSetDimensions(plhs[0], N, NumDims);
    if (NumEl == 0)
        return;                 /* nothing to transform; FFTW rejects n == 0 */

    /*
     * FFTNDSplit takes plain int extents, while mwSize may be a wider type.
     * Copy the dimensions instead of reinterpreting the pointer, and refuse
     * sizes that do not fit (FFTW's guru strides are int as well).
     */
    int Dims[NumDims];
    for (k = 0; k < NumDims; k++) {
        Dims[k] = (int) N[k];
        if (Dims[k] <= 0 || (mwSize) Dims[k] != N[k]) {
            mexErrMsgIdAndTxt( "MATLAB:FFT2mx:tooLarge",
            "Array dimensions are too large");
        }
    }
    if ((int) NumEl <= 0 || (size_t) (int) NumEl != NumEl) {
        mexErrMsgIdAndTxt( "MATLAB:FFT2mx:tooLarge",
        "Array has too many elements");
    }

    mxSetData(plhs[0], mxMalloc( sizeof(double) * NumEl ));
    mxSetImagData(plhs[0], mxMalloc( sizeof(double) * NumEl ));

    XReal = (double *) mxGetPr(prhs[0]);
    XImag = (double *) mxGetPi(prhs[0]);
    if (XImag == NULL) {
        /* Real input has no imaginary plane: supply an all-zero one. */
        ZeroImag = (double *) mxCalloc(NumEl, sizeof(double));
        XImag = ZeroImag;
    }

    if (!fftw_init_threads()) {
        mexErrMsgIdAndTxt( "MATLAB:FFT2mx:threadInit",
        "FFTW3 failed to initialise threads");
    }
    fftw_plan_with_nthreads(numCPU);

    FFTNDSplit(NumDims, Dims, XReal, XImag,
               mxGetPr(plhs[0]), mxGetPi(plhs[0]), -1);

    if (ZeroImag)
        mxFree(ZeroImag);
}
The associated matlab code:
function Y = fft2mx(X,NumCPU)
% FFT2MX  FFT of X computed by the FFT2mx MEX-file, scaled by 1/sqrt(rows*cols).
%   Y = FFT2MX(X, NumCPU) calls FFT2mx(X, NumCPU), where NumCPU is the number
%   of threads passed to the MEX-file, and divides the result by
%   sqrt(size(X,1)*size(X,2)).
%
%   The result is now returned in Y; previously it was computed and then
%   discarded because the function had no output argument. Calling the
%   function without an output (e.g. for timing) still works.
Y = FFT2mx(X,NumCPU)/sqrt(size(X,1)*size(X,2));
return;
I compile the mex code using the static libraries:
mex FFT2mx.cpp /home/nicolas/Code/C/lib/lib/libfftw3.a /home/nicolas/Code/C/lib/lib/libfftw3_threads.a
Everything works well, it is just slower.
The FFTW library has been compiled with the following arguments:
CC="gcc ${BUILD64} -fPIC" CXX="g++ ${BUILD64} -fPIC" \
./configure --prefix=/home/nicolas/Code/C/lib --enable-threads &&
make
make install
I am running this code on one cluster node with 2 Quad-Core AMD Opteron(tm) and I test with:
A = randn([2048 2048])+ i*randn([2048 2048]);
tic, fft2mx(A,8); toc;
tic, fftn(A); toc;
which returns:
Elapsed time is 0.482021 seconds.
Elapsed time is 0.151630 seconds.
How can my MEX code be tuned? Can the compilation of the FFTW library be optimized? Is there a way to speed up the FFTW algorithm while using only the ESTIMATE planner?
I am looking for any insights. Thank you.
EDIT:
I took into account what you suggested (using wisdom and a static plan) and wrote this updated code:
# include <string.h>
# include <stdlib.h>
# include <stdio.h>
# include <mex.h>
# include <matrix.h>
# include <math.h>
# include </home/nicolas/Code/C/lib/include/fftw3.h>
/* Path of the FFTW wisdom file. NULL until set_wisfile() has built it. */
char *Wisfile = NULL;
/* printf-style template for that path; the single %s receives $HOME.
 * const-qualified because it points at a string literal. */
const char *Wistemplate = "%s/.fftwis";
/* Length of the template's fixed part ("/.fftwis"), without the terminator. */
#define WISLEN 8

/*
 * set_wisfile - build the wisdom-file path into Wisfile on first use.
 *
 * Does nothing if Wisfile is already set.  If HOME is not set in the
 * environment, or memory cannot be obtained, Wisfile is left NULL, so callers
 * must check Wisfile before passing it to fopen().
 * The buffer is allocated once and intentionally kept for the life of the
 * process.
 */
void set_wisfile(void)
{
    const char *home;
    char *path;
    int need;

    if (Wisfile)
        return;

    home = getenv("HOME");
    if (home == NULL)
        return;                         /* strlen(NULL) would be undefined */

    /* Ask snprintf for the exact length rather than trusting WISLEN to
     * stay in sync with the template. */
    need = snprintf(NULL, 0, Wistemplate, home);
    if (need < 0)
        return;

    /* The cast keeps this valid when the file is compiled as C++. */
    path = (char *) malloc((size_t) need + 1);
    if (path == NULL)
        return;

    snprintf(path, (size_t) need + 1, Wistemplate, home);
    Wisfile = path;
}
void cleanup(void) {
static fftw_plan PlanForward;
static int planlen;
static double *pr, *pi, *pr2, *pi2;
mexPrintf("MEX-file is terminating, destroying array\n");
fftw_destroy_plan(PlanForward);
fftw_free(pr2);
fftw_free(pi2);
fftw_free(pr);
fftw_free(pi);
}
void mexFunction( int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[] )
{
int i, j, numCPU, NumDims;
const mwSize *N;
fftw_complex *out, *in1;
static double *pr, *pi, *pr2, *pi2;
static int planlen = 0;
static fftw_plan PlanForward;
fftw_iodim Dim[NumDims];
int k, NumEl;
FILE *wisdom;
if (nrhs != 2) {
mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidNumInputs",
"Two input argument required.");
}
if (!mxIsDouble(prhs[0])) {
mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidNumInputs",
"Array must be double");
}
numCPU = (int) mxGetScalar(prhs[1]);
if (numCPU > 8) {
mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidNumInputs",
"NumOfThreads < 8 requested");
}
if (!mxIsComplex(prhs[0])) {
mexErrMsgIdAndTxt( "MATLAB:FFT2mx:invalidNumInputs",
"Array must be complex");
}
NumDims = mxGetNumberOfDimensions(prhs[0]);
N = mxGetDimensions(prhs[0]);
for(k = 0, NumEl = 1; k < NumDims; k++)
{
Dim[NumDims - k - 1].n = N[k];
Dim[NumDims - k - 1].is = Dim[NumDims - k - 1].os = (k == 0) ? 1 : (N[k-1] * Dim[NumDims-k].is);
NumEl *= N[k];
}
/* If different size, free/destroy */
if(N[0] != planlen && planlen > 0) {
fftw_free(pr2);
fftw_free(pi2);
fftw_free(pr);
fftw_free(pi);
fftw_destroy_plan(PlanForward);
planlen = 0;
}
mexAtExit(cleanup);
/* Init */
fftw_init_threads();
// APPROACH 1
//pr = (double *) mxGetPr(prhs[0]);
//pi = (double *) mxGetPi(prhs[0]);
// APPROACH 2
pr = (double *) fftw_malloc( sizeof(double) * mxGetNumberOfElements(prhs[0]) );
pi = (double *) fftw_malloc( sizeof(double) * mxGetNumberOfElements(prhs[0]) );
tmp1 = (double *) mxGetPr(prhs[0]);
tmp2 = (double *) mxGetPi(prhs[0]);
for(k=0;k<mxGetNumberOfElements(prhs[0]);k++)
{
pr[k] = tmp1[k];
pi[k] = tmp2[k];
}
plhs[0] = mxCreateNumericMatrix(0, 0, mxDOUBLE_CLASS, mxCOMPLEX);
mxSetDimensions(plhs[0], N, NumDims);
mxSetData(plhs[0], (double* ) fftw_malloc( sizeof(double) * mxGetNumberOfElements(prhs[0]) ));
mxSetImagData(plhs[0], (double* ) fftw_malloc( sizeof(double) * mxGetNumberOfElements(prhs[0]) ));
pr2 = mxGetPr(plhs[0]);
pi2 = mxGetPi(plhs[0]);
fftw_init_threads();
fftw_plan_with_nthreads(numCPU);
/* Get any accumulated wisdom. */
set_wisfile();
wisdom = fopen(Wisfile, "r");
if (wisdom) {
fftw_import_wisdom_from_file(wisdom);
fclose(wisdom);
}
/* Compute plan */
//printf("%d",planlen);
if(planlen == 0 ) {
fftw_plan_with_nthreads(numCPU);
PlanForward = fftw_plan_guru_split_dft(NumDims, Dim, 0, NULL, pr, pi, pr2, pi2, FFTW_MEASURE);
planlen = N[0];
}
/* Save the wisdom. */
wisdom = fopen(Wisfile, "w");
if (wisdom) {
fftw_export_wisdom_to_file(wisdom);
fclose(wisdom);
}
/* execute */
fftw_execute_split_dft(PlanForward, pr, pi, pr2, pi2);
fftw_cleanup_threads();
}
I am now encountering segmentation faults after several calls (between 2 and 6) to the function, and I cannot figure out why. I tried different ways of initializing the pointers. I also read somewhere that the pointers used by the plan have to be static to work with the corresponding static plan. Do you see anything I am doing wrong?
Thanks again for your insights.
The problem is that you are creating and destroying a plan for each FFT. Creating a plan is typically much more time-consuming than the FFT itself. Ideally you only create and destroy a plan once and then re-use it a number of times for successive FFTs of the same dimension(s).
If you are calling your MEX repeatedly for the same size FFT then you may be able to memoize the plan (e.g. keep a static plan variable and dimension and only recreate the plan as needed, i.e. when the dimension changes).
Alternatively you could have three MEX functions - one for creating a plan, one for running the FFT with a given plan and one for destroying the plan.
Once you have fixed the above architectural problem you should consider using `FFTW_MEASURE` instead of `FFTW_ESTIMATE` for better performance.
One further thing: you might want to add `--enable-sse` to your `./configure` command to enable SIMD code generation in the FFTW butterflies (for a double-precision build such as this one, the corresponding flag is `--enable-sse2`).