Search code examples
Tags: fft, fftw

Can I use FFTW to compute a transform along the first and last axes of a 3D array?


Using fftw_plan_many_dft, I can do 2D transforms over the x,y and y,z axes of a row-major N_X x N_Y x N_Z array:

// Forward 2D FFT over the (y, z) axes of a row-major N_X x N_Y x N_Z array.
// Runs N_X independent N_Y x N_Z transforms; slice x starts at offset x*N_Y*N_Z.
// `input` is taken by value, so the caller's array is never modified.
// NOTE(review): assumes float_type is double, since the buffers are
// reinterpret_cast to FFTW's double-precision fftw_complex -- confirm.
vector<complex<float_type>> yz_fft(vector<complex<float_type>> input, int N_X, int N_Y, int N_Z){
    vector<complex<float_type>> result(input.size());
    int rank = 2;
    int n[] = {N_Y,N_Z};
    int *inembed = n;
    int *onembed = n;
    int istride = 1;            // y,z are the fastest-varying axes: contiguous
    int ostride = 1;
    int idist = N_Y*N_Z;        // next transform = next x slice
    int odist = N_Y*N_Z;
    int howmany = N_X;
    fftw_plan plan = fftw_plan_many_dft(
            rank,
            n,
            howmany,
            reinterpret_cast<fftw_complex *>(input.data()),
            inembed,
            istride,
            idist,
            reinterpret_cast<fftw_complex *>(result.data()),
            onembed,
            ostride,
            odist,
            FFTW_FORWARD,
            FFTW_ESTIMATE);
    fftw_execute(plan);
    // Release the plan and everything FFTW allocated for it; previously leaked per call.
    fftw_destroy_plan(plan);
    return result;
}

// Forward 2D FFT over the (x, y) axes of a row-major N_X x N_Y x N_Z array.
// Runs N_Z independent N_X x N_Y transforms. Within one transform, consecutive
// elements are N_Z apart. Consecutive transforms (successive z) start 1 element apart.
// `input` is taken by value, so the caller's array is never modified.
// NOTE(review): assumes float_type is double, since the buffers are
// reinterpret_cast to FFTW's double-precision fftw_complex -- confirm.
vector<complex<float_type>> xy_fft(vector<complex<float_type>> input, int N_X, int N_Y, int N_Z){
    vector<complex<float_type>> result(input.size());
    int rank = 2;
    int n[] = {N_X,N_Y};
    int *inembed = n;
    int *onembed = n;
    int istride = N_Z;          // step over the z axis to reach the next (x, y) element
    int ostride = N_Z;
    int idist = 1;              // next transform = next z
    int odist = 1;
    int howmany = N_Z;
    fftw_plan plan = fftw_plan_many_dft(
            rank,
            n,
            howmany,
            reinterpret_cast<fftw_complex *>(input.data()),
            inembed,
            istride,
            idist,
            reinterpret_cast<fftw_complex *>(result.data()),
            onembed,
            ostride,
            odist,
            FFTW_FORWARD,
            FFTW_ESTIMATE);
    fftw_execute(plan);
    // Release the plan and everything FFTW allocated for it; previously leaked per call.
    fftw_destroy_plan(plan);
    return result;
}

but I can't figure out how to do the x,z transform. How do I do this?


Solution

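  • There is a way to use fftw_plan_many_dft to do the x,z transform. Downvotes may suggest that people are not interested in it, but I decided to share it anyway. For the solution, see struct xz_fft_many below. The trick is to describe each x,z slice as embedded in a larger virtual N_X x (N_Y*N_Z) array. With inembed = {N_X, N_Y*N_Z}, istride = 1 and idist = N_Z, element (x, z) of the y-th transform is read from offset y*N_Z + x*N_Y*N_Z + z. That is exactly where [x][y][z] lives in the row-major array, so howmany = N_Y transforms cover every y.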

    #include <algorithm>
    #include <cassert>
    #include <complex>
    #include <iostream>
    #include <numeric>
    #include <vector>
    
    #include <fftw3.h>
    
    #include <benchmark/benchmark.h>
    
    
    using namespace std;
    
    // Scalar type of the complex samples. Must remain double: the FFT structs
    // reinterpret_cast complex<float_type>* to fftw_complex*, which is FFTW's
    // double-precision complex type.
    using float_type = double;
    // Unsigned type used for array dimensions and flat element indices.
    using index_type = unsigned long;
    
    // Builds deterministic FFT input: N complex samples whose real parts are
    // 0, 1, ..., N-1 and whose imaginary parts are all zero.
    vector<complex<float_type>> get_data(index_type N){
        vector<complex<float_type>> samples;
        samples.reserve(N);
        for (index_type k = 0; k != N; ++k) {
            samples.emplace_back(static_cast<float_type>(k));
        }
        return samples;
    }
    
    void print(vector<complex<float_type>> data,index_type N_X,index_type N_Y,index_type N_Z){
        for(int i=0; i!=N_X; ++i){
            for(int j=0; j!=N_Y; ++j){
                for(int k=0; k!=N_Z; ++k){
                    index_type idx = i*(N_Y*N_Z)+j*N_Z+k;
                    cout<<"[ "<<i<<", "<<j<<", "<<k<<" ] = "<<data.data()[idx]<<endl;
                }
            }
        }
    }
    
    struct x_fft {
        vector<complex<float_type>>& data;
        vector<complex<float_type>> result;
        fftw_plan fft_plan;
        index_type N_X;
        index_type N_Y;
        index_type N_Z;
    
    
        x_fft(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
                : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
        {
            result = vector<complex<float_type>>(data.size());
            int rank = 1;
            int n[] = {(int)N_X};
            int *inembed = n;
            int *onembed = n;
            int istride = N_Y*N_Z;
            int ostride = istride;
            int idist = 1;
            int odist = idist;
            int howmany = N_Y*N_Z;
            fft_plan = fftw_plan_many_dft(
                    rank,
                    n,
                    howmany,
                    reinterpret_cast<fftw_complex *>(data.data()),
                    inembed,
                    istride,
                    idist,
                    reinterpret_cast<fftw_complex *>(result.data()),
                    onembed,
                    ostride,
                    odist,
                    FFTW_FORWARD,
                    FFTW_MEASURE);
        }
    
        const vector<complex<float_type>> &getResult() const {
            return result;
        }
    
        vector<complex<float_type>>& run(){
            fftw_execute(fft_plan);
            return result;
        }
    
    };
    
    struct z_fft {
        vector<complex<float_type>>& data;
        vector<complex<float_type>> result;
        fftw_plan fft_plan;
        index_type N_X;
        index_type N_Y;
        index_type N_Z;
    
    
        z_fft(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
                : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
        {
            result = vector<complex<float_type>>(data.size());
            int rank = 1;
            int n[] = {(int)N_Z};
            int *inembed = n;
            int *onembed = n;
            int istride = 1;
            int ostride = istride;
            int idist = N_Z;
            int odist = idist;
            int howmany = N_X*N_Y;
            fft_plan = fftw_plan_many_dft(
                    rank,
                    n,
                    howmany,
                    reinterpret_cast<fftw_complex *>(data.data()),
                    inembed,
                    istride,
                    idist,
                    reinterpret_cast<fftw_complex *>(result.data()),
                    onembed,
                    ostride,
                    odist,
                    FFTW_FORWARD,
                    FFTW_MEASURE);
        }
    
        vector<complex<float_type>>& run(){
            fftw_execute(fft_plan);
            return result;
        }
    
    };
    
    
    struct xz_fft_many {
        vector<complex<float_type>>& data;
        vector<complex<float_type>> result;
        fftw_plan fft_plan;
        index_type N_X;
        index_type N_Y;
        index_type N_Z;
    
    
        xz_fft_many(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
                : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
        {
            result = vector<complex<float_type>>(data.size());
            int rank = 2;
            int n[] = {(int) N_X, (int) N_Z};
            int inembed[] = {(int) N_X, (int) (N_Z*N_Y)};
            int *onembed = inembed;
            int istride = 1;
            int ostride = 1;
            int idist = N_Z;
            int odist = N_Z;
            int howmany = N_Y;
            fft_plan = fftw_plan_many_dft(
                    rank,
                    n,
                    howmany,
                    reinterpret_cast<fftw_complex *>(data.data()),
                    inembed,
                    istride,
                    idist,
                    reinterpret_cast<fftw_complex *>(result.data()),
                    onembed,
                    ostride,
                    odist,
                    FFTW_FORWARD,FFTW_MEASURE);
        }
    
        vector<complex<float_type>>& run(){
            fftw_execute(fft_plan);
            return result;
        }
    
    };
    
    struct xz_fft_composition {
        vector<complex<float_type>>& data;
        index_type N_X;
        index_type N_Y;
        index_type N_Z;
        x_fft* xFft;
        z_fft* zFft;
    
    
        xz_fft_composition(vector<complex<float_type>>& data,index_type N_X,index_type N_Y,index_type N_Z)
                : data(data), N_X(N_X), N_Y(N_Y), N_Z(N_Z)
        {
            xFft = new x_fft(data,N_X,N_Y,N_Z);
            zFft = new z_fft(xFft->result,N_X,N_Y,N_Z);
        }
    
        vector<complex<float_type>>& run(){
            xFft->run();
            return zFft->run();
        }
    
    };
    
    // Problem size and input array shared by the sanity check and both benchmarks.
    // Members are initialised in declaration order, so ARRAY_SIZE and data can
    // rely on the dimensions declared above them.
    struct TestData{
        index_type N_X{512};
        index_type N_Y{16};
        index_type N_Z{16};
    
        index_type ARRAY_SIZE{N_X * N_Y * N_Z};
        std::vector<complex<float_type>> data = get_data(ARRAY_SIZE);
    
        TestData() {
            // To dump the input, call: print(data,N_X,N_Y,N_Z);
        }
    };
    
    TestData testData;
    
    struct SanityTest{
        SanityTest() {
            xz_fft_many fft_many(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
            xz_fft_composition fft_composition(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
            std::vector<complex<float_type>> fft_many_result =  fft_many.run();
            std::vector<complex<float_type>> fft_composition_result =  fft_composition.run();
    
            bool equal = std::equal(fft_composition_result.begin(), fft_composition_result.end(), fft_many_result.begin());
            assert(equal);
            if(equal){
                cout << "ok" << endl;
            }
        }
    };
    
    SanityTest sanityTest;
    
    // Benchmark: executes the single-plan x,z transform. Planning happens before
    // the timed loop and is not measured.
    static void XZ_test_many(benchmark::State& state) {
        xz_fft_many fft(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
        for (auto _ : state) {
            // Bind by reference. The old `auto result = fft.run()` copied the whole
            // output vector each iteration, adding a large memcpy to the timing.
            auto& result = fft.run();
            benchmark::DoNotOptimize(result);
            benchmark::ClobberMemory();
        }
    }
    
    // Benchmark: executes the two-pass (x then z) transform. Planning happens
    // before the timed loop and is not measured.
    static void XZ_test_composition(benchmark::State& state) {
        xz_fft_composition fft(testData.data, testData.N_X, testData.N_Y, testData.N_Z);
        for (auto _ : state) {
            // Bind by reference. The old `auto result = fft.run()` copied the whole
            // output vector each iteration, adding a large memcpy to the timing.
            auto& result = fft.run();
            benchmark::DoNotOptimize(result);
            benchmark::ClobberMemory();
        }
    }
    
    // Register both benchmarks with the same fixed iteration count (1000) so
    // the two approaches are timed over an equal number of runs.
    BENCHMARK(XZ_test_many)->Iterations(1000);
    BENCHMARK(XZ_test_composition)->Iterations(1000);
    
    // Google Benchmark macro that defines main() and runs the registered benchmarks.
    BENCHMARK_MAIN();
    

    If I did the benchmarks correctly, there are some significant differences between the fftw_plan_many_dft and composition approaches for different N_X, N_Y, N_Z combinations. For example, using

        index_type N_X = 512;
        index_type N_Y = 16;
        index_type N_Z = 16;
    

    I got almost a 2x difference in favour of fftw_plan_many_dft, but for other sets of input parameters I often found the composition approach to be faster, though not by as much.

    ------------------------------------------------------------------------------
    Benchmark                                    Time             CPU   Iterations
    ------------------------------------------------------------------------------
    XZ_test_many/iterations:1000           1412647 ns      1364813 ns         1000
    XZ_test_composition/iterations:1000    2619807 ns      2542472 ns         1000