Tags: fortran, openmp, fftw

How to use multithreaded FFTW inside an OpenMP Fortran code


I have a working code compiled with gcc/gfortran-14 (installed via brew). The vast majority of the run time is spent doing FFTs via FFTW. In the critical part of the code I have this:

!$omp parallel
!$omp sections
!$omp section
       call fft(xi3d,xo3d,.false.,xip,layered_p)
!$omp section
       call fft(yi3d,yo3d,.false.,yip,layered_p)
!$omp section
       call fft(zi3d,zo3d,.false.,zip,layered_p)
!$omp end sections
!$omp sections
!$omp section
       x3d=n3d(:,:,:,1,1)*xo3d+n3d(:,:,:,2,1)*yo3d+n3d(:,:,:,3,1)*zo3d
!$omp section
       y3d=n3d(:,:,:,4,1)*xo3d+n3d(:,:,:,5,1)*yo3d+n3d(:,:,:,6,1)*zo3d
!$omp section
       z3d=n3d(:,:,:,7,1)*xo3d+n3d(:,:,:,8,1)*yo3d+n3d(:,:,:,9,1)*zo3d
!$omp end sections
!$omp sections
!$omp section
       call fft(rx3d,x3d,.true.,xp,layered_p)
!$omp section
       call fft(ry3d,y3d,.true.,yp,layered_p)
!$omp section
       call fft(rz3d,z3d,.true.,zp,layered_p)
!$omp end sections
!$omp end parallel

where xi3d/yi3d/zi3d and rx3d/ry3d/rz3d are real 3D arrays and all the others are complex 3D arrays (all created with fftw_alloc_real/fftw_alloc_complex). The fft subroutine simply calls fftw_plan_dft_c2r_3d or fftw_plan_dft_r2c_3d, depending on the third parameter (.true. or .false.).
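
(For context, here is a minimal sketch of what such a wrapper can look like with the FFTW Fortran 2003 interface. The plan caching, dummy-argument names, and the ignored layered_p are guesses, not the actual routine; note that fftw3.f03 mirrors the C API, so the dimensions are passed in reversed order.)

     ! hypothetical sketch of the wrapper, not the actual routine
     subroutine fft(rarr, carr, inverse, plan, layered_p)
       use, intrinsic :: iso_c_binding
       implicit none
       include 'fftw3.f03'
       real(C_DOUBLE),            intent(inout) :: rarr(:,:,:)
       complex(C_DOUBLE_COMPLEX), intent(inout) :: carr(:,:,:)
       logical,     intent(in)    :: inverse
       type(C_PTR), intent(inout) :: plan        ! assumed initialized to c_null_ptr
       logical,     intent(in)    :: layered_p   ! ignored in this sketch
       integer(C_INT) :: n1, n2, n3

       n1 = size(rarr,1); n2 = size(rarr,2); n3 = size(rarr,3)
       ! create the plan on first use; FFTW's planner is not thread-safe,
       ! so in the real code the plans must exist before the parallel sections
       if (.not.c_associated(plan)) then
         if (inverse) then
           plan = fftw_plan_dft_c2r_3d(n3, n2, n1, carr, rarr, FFTW_ESTIMATE)
         else
           plan = fftw_plan_dft_r2c_3d(n3, n2, n1, rarr, carr, FFTW_ESTIMATE)
         endif
       endif
       if (inverse) then
         call fftw_execute_dft_c2r(plan, carr, rarr)   ! complex -> real
       else
         call fftw_execute_dft_r2c(plan, rarr, carr)   ! real -> complex
       endif
     end subroutine fft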

I usually run with three threads, so the FFTs run in parallel. This works really well and the code sees an almost 3x performance improvement.

Now I'm trying to use OpenMP within FFTW itself, i.e. tell FFTW to use 2 or 3 threads in each fft instead of one. So at the start of the code I do the required initialization:

     if (fftw_init_threads().eq.0) then
       print *,'fftw has problem: fftw_init_threads'
     else
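       ! one third of the available threads for each of the three concurrent FFTs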
       call fftw_plan_with_nthreads(max(1,omp_get_max_threads()/3))
       print *,'each fftw will use ',max(1,omp_get_max_threads()/3),' threads'
   !   call omp_set_nested(.true.)
     endif

However, when I run the code (even after setting export OMP_NUM_THREADS=6 or export OMP_NUM_THREADS=3,2), it runs slower. With 6, top tells me only three processors are running. With 3,2 it shows about 4.5 processors running, but in both cases the code runs slower than just using 3 threads without FFTW's OpenMP.

The arrays are 512x512x256. I'm on a MacBook Pro with an M1 Pro and 32GB of memory.

Do I need the omp_set_nested(.true.) or not? What can I do to see all 6 processors pegged during the FFTs (which, in this code, account for over 95% of the time!)?

Any suggestions on what to try?


Solution

  • From an OpenMP perspective, you want to execute two nested parallel regions: 3 threads at the outer level, each of which then executes n threads at the inner level. Starting with OpenMP 5.0, nested-var and its API were deprecated; nesting is now controlled either by setting max-active-levels-var or by the elements of the nthreads-var list (3,n in your case). omp_get_max_threads() is defined to return the head of the nthreads-var list, so from a serial context it will give you 3.

    In your code, you therefore either tell FFTW to plan for execution with 2 threads (OMP_NUM_THREADS=6) but then execute with a single thread, or you tell FFTW to plan for execution with 1 thread (OMP_NUM_THREADS=3,2) and then execute with 2 threads.
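
    The list semantics are easy to verify with a toy program (a sketch using only omp_lib): run with OMP_NUM_THREADS=3,2, it prints 3 from the serial context and 2 from inside the parallel region.

         program check_list
           use omp_lib
           implicit none
           print *,'serial: ',omp_get_max_threads()   ! head of the list: 3
    !$omp parallel
    !$omp single
           print *,'inner : ',omp_get_max_threads()   ! next element: 2
    !$omp end single
    !$omp end parallel
         end program check_list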

    To plan properly for FFTW execution, the tricky part is accessing n from the list. You can spawn a parallel region to emulate your 3-section region, which pops the 3 from the nthreads-var list. At that point you can access the second value using omp_get_max_threads().

         if (fftw_init_threads().eq.0) then
           print *,'fftw has problem: fftw_init_threads'
         else
    !$omp parallel
    !$omp single
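       ! the 3 has now been popped off the nthreads-var list, so
       ! omp_get_max_threads() here returns the inner value n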
           call fftw_plan_with_nthreads(max(1,omp_get_max_threads()))
           print *,'each fftw will use ',max(1,omp_get_max_threads()),' threads'
    !$omp end single
    !$omp end parallel
         endif
    

    GCC implements nesting based on the nthreads-var list at least since version 9, so you don't need to additionally call omp_set_nested.
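
  • Two practical FFTW details worth double-checking (from the FFTW documentation, not specific to this code): fftw_plan_with_nthreads only affects plans created after it is called, so every plan has to be created (or recreated) once the thread count has been set; and the threaded routines live in a separate library, so the code must be linked with -lfftw3_omp in addition to -lfftw3 (and compiled with -fopenmp).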