Search code examples
openclfft

Higher radix (or better) formulation for Stockham FFT


Background

I've implemented this algorithm from Microsoft Research for a radix-2 FFT (Stockham auto sort) using OpenCL.

I use floating point textures (256 cols X N rows) for input and output in the kernel, because I will need to sample at non-integral points and I thought it better to delegate that to the texture sampling hardware. Note that my FFTs are always of 256-point sequences (every row in my texture). At this point, my N is 16384 or 32768 depending on the GPU I'm using and the max 2D texture size allowed.

I also need to perform the FFT of 4 real-valued sequences at once, so the kernel performs the FFT(a, b, c, d) as FFT(a + ib, c + id) from which I can extract the 4 complex sequences out later using an O(n) algorithm. I can elaborate on this if someone wishes - but I don't believe it falls in the scope of this question.

Kernel Source

// Sampler used for every texture read in the FFT kernel: unnormalized (pixel)
// coordinates, clamp-to-edge addressing, nearest-texel filtering.
const sampler_t fftSampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;

// One radix-2 Stockham pass. Each work-item produces the texel at (column, row)
// of `output` from two texels of the same row of `input`.
//
// Every texel packs two complex values as (re0, im0, re1, im1); both are
// combined with the same twiddle factor.
//
//   input   - texture read through fftSampler
//   output  - texture receiving the result of this pass
//   fftSize - sub-transform length for this pass
//   size    - the second operand is read size/2 columns after the first
__kernel void FFT_Stockham(read_only image2d_t input, write_only image2d_t output, int fftSize, int size)
{
    const int col = get_global_id(0);
    const int row = get_global_id(1);

    // Locate the two source columns for this output column.
    const int halfFft = fftSize / 2;
    const int blockBase = floor(col / convert_float(fftSize)) * halfFft;
    const int firstCol = blockBase + (col % halfFft);
    const int secondCol = firstCol + (size / 2);

    const float4 first = read_imagef(input, fftSampler, (int2)(firstCol, row));
    const float4 second = read_imagef(input, fftSampler, (int2)(secondCol, row));

    // Twiddle factor exp(-i * 2*pi * col / fftSize).
    const float angle = -6.283185f * (convert_float(col) / convert_float(fftSize));

    // TODO: Convert the two calculations below into lookups from a __constant buffer
    const float c = native_cos(angle);
    const float s = native_sin(angle);

    // Complex multiply of (c + i*s) with both values of `second`, split into
    // the part scaled by the real components and the part scaled by the
    // imaginary components.
    const float4 byReal = (float4)(c, s, c, s);
    const float4 byImag = (float4)(-s, c, -s, c);
    const float4 combined = first + byReal * second.xxzz + byImag * second.yyww;

    write_imagef(output, (int2)(col, row), combined);
}

The host code simply invokes this kernel log2(256) times, ping-ponging the input and output textures.

Note: I tried removing the native_cos and native_sin to see if that impacted timing, but it doesn't seem to change things by very much. Not the factor I'm looking for, in any case.

Access pattern

Knowing that I am probably memory-bandwidth bound, here is the memory access pattern (per row) for my radix-2 FFT.

FFT Access pattern

  • X0 - element 1 to combine (read)
  • X1 - element 2 to combine (read)
  • X - element to write to (write)

Question

So my question is - can someone help me with/point me toward a higher-radix formulation for this algorithm? I ask because most FFTs are optimized for large cases and single real/complex valued sequences. Their kernel generators are also very case dependent and break down quickly when I try to muck with their internals.

Are there other options better than simply going to a radix-8 or 16 kernel?

Some of my constraints are - I have to use OpenCL (no cuFFT). I also cannot use clAmdFft from ACML for this purpose. It would be nice to also talk about CPU optimizations (this kernel SUCKS big time on the CPU) - but getting it to run in fewer iterations on the GPU is my main use-case.


Solution

  • I tried several versions, but the one with the best performance on CPU and GPU was a radix-16 kernel for my specific case.

    Here is the kernel for reference. It was taken from Eric Bainville's (most excellent) website and used with full attribution.

    // #define M_PI 3.14159265358979f
    
    // Conjugates and scales x in place: every float4 is multiplied by
    // (Scale, -Scale, Scale, -Scale), i.e. the second and fourth components
    // (the imaginary parts of the two packed complex values) are negated.
    // Global size is x.Length/2, Scale = 1 for direct, 1/N for inverse (iFFT).
    __kernel void ConjugateAndScale(__global float4* x, const float Scale)
    {
       const int idx = get_global_id(0);

       x[idx] = x[idx] * (float4)(Scale, -Scale, Scale, -Scale);
    }
    
    
    // Multiply a by -I, i.e. by EXP(-I*PI/2): (x + I*y) -> (y - I*x).
    float2 mul_p1q2(float2 a)
    {
      float2 rotated;
      rotated.x = a.y;
      rotated.y = -a.x;
      return rotated;
    }
    
    // Complex square: (x + I*y)^2 = (x*x - y*y) + I*(2*x*y).
    float2 sqr_1(float2 a)
    {
      float2 squared;
      squared.x = a.x*a.x-a.y*a.y;
      squared.y = 2.0f*a.x*a.y;
      return squared;
    }
    
    // Two DFT2s over the four complex numbers packed in A.
    // With A=(a,b,c,d) the result is (a',b',c',d') where (a',c')=DFT2(a,c)
    // and (b',d')=DFT2(b,d): the low half receives the sums, the high half
    // the differences.
    float8 dft2_4(float8 a)
    {
      float8 r;
      r.lo = a.lo + a.hi;
      r.hi = a.lo - a.hi;
      return r;
    }
    
    // DFT of the 4 complex numbers packed in A, built from two rounds of
    // paired DFT2s with a shuffle and a -I twiddle in between.
    float8 dft4_4(float8 a)
    {
      // First round: 2x DFT2
      const float8 first = dft2_4(a);

      // Shuffle the intermediate values and rotate the last one by -I
      float8 shuffled;
      shuffled.s01 = first.lo.lo;
      shuffled.s23 = first.hi.lo;
      shuffled.s45 = first.lo.hi;
      shuffled.s67 = mul_p1q2(first.hi.hi);

      // Second round: 2x DFT2
      return dft2_4(shuffled);
    }
    
    // Complex product, multiply vectors of complex numbers.
    // Operands store real parts in the .even components and imaginary parts
    // in the .odd components. MUL_RE yields the real part(s) of a*b and
    // MUL_IM the imaginary part(s).
    // Arguments are parenthesized so that any expression (not just a plain
    // identifier) can be passed safely.
    #define MUL_RE(a,b) ((a).even*(b).even - (a).odd*(b).odd)
    #define MUL_IM(a,b) ((a).even*(b).odd + (a).odd*(b).even)
    
    // Product a*b of two complex numbers stored as (re, im).
    float2 mul_1(float2 a, float2 b)
    {
      return (float2)(MUL_RE(a,b), MUL_IM(a,b));
    }
    // Pairwise complex product of two float4 values, each holding two complex
    // numbers as (re0, im0, re1, im1).
    float4 mul_1_F4(float4 a, float4 b)
    {
      const float2 re = MUL_RE(a,b); // real parts of both products
      const float2 im = MUL_IM(a,b); // imaginary parts of both products
      return (float4)(re.x, im.x, re.y, im.y);
    }
    
    
    // Pairwise complex product of two float4 values (two complex numbers
    // each). Same computation as mul_1_F4, so it simply forwards to it.
    float4 mul_2(float4 a, float4 b)
    {
      return mul_1_F4(a, b);
    }
    
    // DFT2 of the two complex numbers packed in A: the low half of the result
    // is their sum, the high half their difference.
    float4 dft2_2(float4 a)
    {
      float4 r;
      r.lo = a.lo+a.hi;
      r.hi = a.lo-a.hi;
      return r;
    }
    
    // Return cos(alpha)+I*sin(alpha) as (re, im).
    // This variant uses the regular cos()/sin() built-ins; sincos() or
    // native_cos()/native_sin() are the other two options.
    float2 exp_alpha_1(float alpha)
    {
      return (float2)(cos(alpha), sin(alpha));
    }
    // Return cos(alpha)+I*sin(alpha) duplicated into both halves of a float4,
    // i.e. (cos, sin, cos, sin), so it can twiddle two packed complex values.
    float4 exp_alpha_1_F4(float alpha)
    {
      const float2 e = exp_alpha_1(alpha);
      return (float4)(e, e);
    }
    
    
    // mul_p*q*(a) returns a*EXP(-I*PI*P/Q)
    // Rotations that reduce to a simpler fraction are aliases of the simpler
    // one; only the irreducible ones are real functions.
    #define mul_p0q1(a) (a)

    #define mul_p0q2 mul_p0q1
    //float2  mul_p1q2(float2 a) { return (float2)(a.y,-a.x); }

    // The constants below carry an 'f' suffix so they are single-precision
    // literals and the program does not depend on double-precision support.
    __constant float SQRT_1_2 = 0.707106781186548f; // cos(Pi/4)
    #define mul_p0q4 mul_p0q2
    float2  mul_p1q4(float2 a) { return (float2)(SQRT_1_2)*(float2)(a.x+a.y,-a.x+a.y); }
    #define mul_p2q4 mul_p1q2
    float2  mul_p3q4(float2 a) { return (float2)(SQRT_1_2)*(float2)(-a.x+a.y,-a.x-a.y); }

    __constant float COS_8 = 0.923879532511287f; // cos(Pi/8)
    __constant float SIN_8 = 0.382683432365089f; // sin(Pi/8)
    #define mul_p0q8 mul_p0q4
    float2  mul_p1q8(float2 a) { return mul_1((float2)(COS_8,-SIN_8),a); }
    #define mul_p2q8 mul_p1q4
    float2  mul_p3q8(float2 a) { return mul_1((float2)(SIN_8,-COS_8),a); }
    #define mul_p4q8 mul_p2q4
    float2  mul_p5q8(float2 a) { return mul_1((float2)(-SIN_8,-COS_8),a); }
    #define mul_p6q8 mul_p3q4
    float2  mul_p7q8(float2 a) { return mul_1((float2)(-COS_8,-SIN_8),a); }
    
    // Compute in-place DFT2 and twiddle: a <- a + b, b <- t(a - b), where t is
    // one of the mul_p*q* rotations (function or macro).
    // Wrapped in do { } while (0) so that "DFT2_TWIDDLE(...);" is exactly one
    // statement (safe inside if/else), and arguments are parenthesized so
    // arbitrary expressions can be passed.
    #define DFT2_TWIDDLE(a,b,t) \
      do { float2 tmp = t((a)-(b)); (a) += (b); (b) = tmp; } while (0)
    
    // One radix-16 FFT pass.
    // Every float4 element carries two complex values (.lo and .hi), and the
    // same butterflies are applied to both, so two complex sequences are
    // transformed at once.
    //
    // T = N/16 = number of threads.
    // P is the length of input sub-sequences, 1,16,256,...,N/16.
    //
    //   x  - input; this work-item reads x[I + m*T] for m = 0..15
    //   y  - output; this work-item writes y[J + m*P] for m = 0..15
    //   pp - P for this pass; k is extracted with a bit mask, so the index
    //        arithmetic assumes P is a power of two
    __kernel void FFT_Radix16(__global const float4 * x, __global float4 * y, int pp)
    {
      const int p = pp;
      const int t = get_global_size(0); // number of threads
      const int i = get_global_id(0); // current thread

      const int k = i & (p-1); // index in input sequence, in 0..P-1
      // Inputs indices are I+{0,..,15}*T
      x += i;
      // Output indices are J+{0,..,15}*P, where
      // J is I with four 0 bits inserted at bit log2(P)
      y += ((i-k)<<4) + k;

      // Load
      float4 u[16];
      for (int m=0;m<16;m++) u[m] = x[m*t];

      // Twiddle, twiddling factors are exp(-I*PI*{0,..,15}*K/(8P)).
      // M_PI_F is the single-precision constant; M_PI is a double and is only
      // usable when the device supports double precision.
      const float alpha = -M_PI_F*(float)k/(float)(8*p);
      for (int m=1;m<16;m++) u[m] = mul_1_F4(exp_alpha_1_F4(m * alpha), u[m]);

      // 8x in-place DFT2 and twiddle (1)
      DFT2_TWIDDLE(u[0].lo,u[8].lo,mul_p0q8);
      DFT2_TWIDDLE(u[0].hi,u[8].hi,mul_p0q8);

      DFT2_TWIDDLE(u[1].lo,u[9].lo,mul_p1q8);
      DFT2_TWIDDLE(u[1].hi,u[9].hi,mul_p1q8);

      DFT2_TWIDDLE(u[2].lo,u[10].lo,mul_p2q8);
      DFT2_TWIDDLE(u[2].hi,u[10].hi,mul_p2q8);

      DFT2_TWIDDLE(u[3].lo,u[11].lo,mul_p3q8);
      DFT2_TWIDDLE(u[3].hi,u[11].hi,mul_p3q8);

      DFT2_TWIDDLE(u[4].lo,u[12].lo,mul_p4q8);
      DFT2_TWIDDLE(u[4].hi,u[12].hi,mul_p4q8);

      DFT2_TWIDDLE(u[5].lo,u[13].lo,mul_p5q8);
      DFT2_TWIDDLE(u[5].hi,u[13].hi,mul_p5q8);

      DFT2_TWIDDLE(u[6].lo,u[14].lo,mul_p6q8);
      DFT2_TWIDDLE(u[6].hi,u[14].hi,mul_p6q8);

      DFT2_TWIDDLE(u[7].lo,u[15].lo,mul_p7q8);
      DFT2_TWIDDLE(u[7].hi,u[15].hi,mul_p7q8);


      // 8x in-place DFT2 and twiddle (2)
      DFT2_TWIDDLE(u[0].lo,u[4].lo,mul_p0q4);
      DFT2_TWIDDLE(u[0].hi,u[4].hi,mul_p0q4);

      DFT2_TWIDDLE(u[1].lo,u[5].lo,mul_p1q4);
      DFT2_TWIDDLE(u[1].hi,u[5].hi,mul_p1q4);

      DFT2_TWIDDLE(u[2].lo,u[6].lo,mul_p2q4);
      DFT2_TWIDDLE(u[2].hi,u[6].hi,mul_p2q4);

      DFT2_TWIDDLE(u[3].lo,u[7].lo,mul_p3q4);
      DFT2_TWIDDLE(u[3].hi,u[7].hi,mul_p3q4);

      DFT2_TWIDDLE(u[8].lo,u[12].lo,mul_p0q4);
      DFT2_TWIDDLE(u[8].hi,u[12].hi,mul_p0q4);

      DFT2_TWIDDLE(u[9].lo,u[13].lo,mul_p1q4);
      DFT2_TWIDDLE(u[9].hi,u[13].hi,mul_p1q4);

      DFT2_TWIDDLE(u[10].lo,u[14].lo,mul_p2q4);
      DFT2_TWIDDLE(u[10].hi,u[14].hi,mul_p2q4);

      DFT2_TWIDDLE(u[11].lo,u[15].lo,mul_p3q4);
      DFT2_TWIDDLE(u[11].hi,u[15].hi,mul_p3q4);

      // 8x in-place DFT2 and twiddle (3)
      DFT2_TWIDDLE(u[0].lo,u[2].lo,mul_p0q2);
      DFT2_TWIDDLE(u[0].hi,u[2].hi,mul_p0q2);

      DFT2_TWIDDLE(u[1].lo,u[3].lo,mul_p1q2);
      DFT2_TWIDDLE(u[1].hi,u[3].hi,mul_p1q2);

      DFT2_TWIDDLE(u[4].lo,u[6].lo,mul_p0q2);
      DFT2_TWIDDLE(u[4].hi,u[6].hi,mul_p0q2);

      DFT2_TWIDDLE(u[5].lo,u[7].lo,mul_p1q2);
      DFT2_TWIDDLE(u[5].hi,u[7].hi,mul_p1q2);

      DFT2_TWIDDLE(u[8].lo,u[10].lo,mul_p0q2);
      DFT2_TWIDDLE(u[8].hi,u[10].hi,mul_p0q2);

      DFT2_TWIDDLE(u[9].lo,u[11].lo,mul_p1q2);
      DFT2_TWIDDLE(u[9].hi,u[11].hi,mul_p1q2);

      DFT2_TWIDDLE(u[12].lo,u[14].lo,mul_p0q2);
      DFT2_TWIDDLE(u[12].hi,u[14].hi,mul_p0q2);

      DFT2_TWIDDLE(u[13].lo,u[15].lo,mul_p1q2);
      DFT2_TWIDDLE(u[13].hi,u[15].hi,mul_p1q2);

      // 8x DFT2 and store (reverse binary permutation)
      y[0]    = u[0]  + u[1];
      y[p]    = u[8]  + u[9];
      y[2*p]  = u[4]  + u[5];
      y[3*p]  = u[12] + u[13];
      y[4*p]  = u[2]  + u[3];
      y[5*p]  = u[10] + u[11];
      y[6*p]  = u[6]  + u[7];
      y[7*p]  = u[14] + u[15];
      y[8*p]  = u[0]  - u[1];
      y[9*p]  = u[8]  - u[9];
      y[10*p] = u[4]  - u[5];
      y[11*p] = u[12] - u[13];
      y[12*p] = u[2]  - u[3];
      y[13*p] = u[10] - u[11];
      y[14*p] = u[6]  - u[7];
      y[15*p] = u[14] - u[15];
    }
    

    Note that I have modified the kernel to perform the FFT of 2 complex-valued sequences at once instead of one. Also, since I only need the FFT of 256 elements at a time in a much larger sequence, I perform only 2 runs of this kernel, which leaves me with 256-length DFTs in the larger array.

    Here's some of the relevant host code as well.

    // Host-side driver: enqueues fftKernel repeatedly, multiplying fftSize by
    // 16 after each launch, and ping-pongs the two buffers memA/memB.
    // ev holds the event of the launch being enqueued, pEv the event of the
    // previous launch (used as its wait list).
    var ev = new[] { new Cl.Event() };
    var pEv = new[] { new Cl.Event() };
    
    int fftSize = 1;
    int iter = 0;
    // Upper bound for fftSize; presumably derived from the per-row transform
    // length - TODO confirm what distributionSize represents.
    int n = distributionSize >> 5;
    while (fftSize <= n)
    {
        // Argument 0 is the buffer read by this pass, argument 1 the buffer
        // written, argument 2 the current sub-sequence length.
        // NOTE(review): the results of these SetKernelArg calls are not
        // checked, unlike the enqueue below.
        Cl.SetKernelArg(fftKernel, 0, memA);
        Cl.SetKernelArg(fftKernel, 1, memB);
        Cl.SetKernelArg(fftKernel, 2, fftSize);
    
        // The first launch has no wait list; later launches wait on the
        // previous launch's event.
        Cl.EnqueueNDRangeKernel(commandQueue, fftKernel, 1, null, globalWorkgroupSize, localWorkgroupSize,
            (uint)(iter == 0 ? 0 : 1),
            iter == 0 ? null : pEv,
            out ev[0]).Check();
        // The previous event has been handed to the enqueue call and is
        // released here; the arrays are then swapped so the new event becomes
        // "previous".
        if (iter > 0)
            pEv[0].Dispose();
        Swap(ref ev, ref pEv);
    
        Swap(ref memA, ref memB); // ping-pong
    
        fftSize = fftSize << 4;
        iter++;
    
        // Blocks until the queue is drained on every iteration.
        // NOTE(review): with this call in the loop the event wait list looks
        // redundant - confirm whether both are intended.
        Cl.Finish(commandQueue);
    }
    
    // After the loop memA refers to the buffer written last; this swap moves
    // that reference into memB.
    // NOTE(review): the event left in pEv[0] is not disposed in this fragment.
    Swap(ref memA, ref memB);
    

    Hope this helps someone!