Tags: numpy, matplotlib, fft, dft, cmath

FFT algorithm yields imprecise results


I'm trying to implement the FFT (fast Fourier transform) based on a matrix factorization of the DFT (discrete Fourier transform). To test the validity of my FFT implementation, the following code computes the transform two ways: with my FFT, and with the straightforward method, i.e. multiplying v directly by the DFT matrix.

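import numpy as n
import cmath, math
import matplotlib.pyplot as plt

v=n.array([1,-1,2,-3])
w=v                # w keeps referring to the original input array
N=len(v)
t=[0]*N            # scratch buffer for the butterflies
M=n.zeros((N,N),dtype=complex)
z=n.exp(2j*math.pi/N)
# straightforward method: build the DFT matrix and multiply v by it
for a in range(N):
    for b in range(N):
        M[a][b]=n.exp(2j*math.pi*a*b/N)
print (n.dot(v,M))
plt.plot(n.dot(v,M))
# move the even-indexed entries to the front, the odd-indexed ones to the back
def f(x):
    x=n.concatenate([x[::2],x[1::2]])
    return x

# reorder v: keep splitting until one more split would restore the input order
while (w!=f(v)).any():
    v=f(v)
print(v)
# butterfly stages: merge sub-transforms of length a/2 into length a
a=2
while a<=N:

    for k in range(N/a):
        for y in range(a/2):
            t[y]=v[a*k+y]
        for i in range(a/2):
            v[a*k+i]+=v[a*k+i+a/2]*(z**i)
            v[a*k+i+a/2]=t[i]-v[a*k+i+a/2]*(z**i)
    a*=2    
print(v)
plt.plot(v)

plt.show()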

I've tried this with many values of v. Sometimes the two methods give exactly the same result. Other times the results are close to each other but not identical. After a few tests, each with a different v, the two have not yet drifted far apart.

Is there anything that I'm missing that causes the imprecision of the code?

EDIT: Please note that the code is written for Python 2 (it relies on implicit integer division, e.g. N/a and a/2).
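For reference, a radix-2 FFT rests on the decimation-in-time split of the DFT matrix. The derivation below uses the same sign convention as the matrix M in the question, $\omega_N = e^{2\pi i/N}$. $E_k$ and $O_k$ denote the length-$N/2$ transforms of the even- and odd-indexed entries of v. One combine step then reads:

$$
X_k=\sum_{n=0}^{N-1} v_n\,\omega_N^{nk},\qquad \omega_N=e^{2\pi i/N}
$$

$$
E_k=\sum_{m=0}^{N/2-1} v_{2m}\,\omega_{N/2}^{mk},\qquad O_k=\sum_{m=0}^{N/2-1} v_{2m+1}\,\omega_{N/2}^{mk}
$$

$$
X_k=E_k+\omega_N^{k}\,O_k,\qquad X_{k+N/2}=E_k-\omega_N^{k}\,O_k,\qquad 0\le k<N/2
$$

Applied recursively, the stage that merges two transforms of length $a/2$ into one of length $a$ uses the twiddle factors $\omega_a^{i}=e^{2\pi i\,i/a}$ for $0\le i<a/2$. For $a<N$ these differ from $z^i$ (with $z=e^{2\pi i/N}$ as in the code) whenever $i>0$.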


Solution

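  • It seems that the problem is not in the algorithm but in the declaration of v (thanks @kazemakase). n.array([1,-1,2,-3]) creates an integer array. Every complex value that the butterfly loop writes back into v is therefore cast to an integer: NumPy discards the imaginary part and keeps only the integer part of the real component. Try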

    v=n.array([1,-1,2,-3], dtype=complex) 
    

    instead. At least for me the curves then appear on top of each other:

    [Image: plot in which the two curves lie on top of each other]

    EDIT

    This was quite the journey. I wasn't able to figure out exactly what is wrong with your code, but it looks like there are several errors, in both the DFT and the FFT parts. In the end I wrote my own version of the FFT based on this document (http://www.cs.cmu.edu/afs/andrew/scs/cs/15-463/2001/pub/www/notes/fourier/fourier.pdf). Pages 6 to 9 hold all the information you need. Maybe you can go through the algorithm and figure out where your problems lie. The bit-reversal algorithm comes from this answer (or alternatively this one); both are linked in the comments of bit_reverse below. I tested the code on vectors of different power-of-two lengths filled with linearly increasing values. Let me know if you find any mistakes.

    import numpy as np
    import cmath
    
    def bit_reverse(x,n):
        """
        Reverse the last n bits of x
        """
    
        ##from https://stackoverflow.com/a/12682003/2454357
        ##formstr = '{{:0{}b}}'.format(n)
        ##return int(formstr.format(x)[::-1],2)
    
        ##from https://stackoverflow.com/a/5333563/2454357
        return sum(1<<(n-1-i) for i in range(n) if x>>i&1)
    
    def permute_vector(v):
        """
        Permute vector v such that the indices of the result
        correspond to the bit-reversed indices of the original.
        Returns the permuted input vector and the number of bits used.
        """
        ##check that len(v) == 2**n
        ##and at the same time find permutation length:
        L = len(v)
        comp = 1
        bits = 0
        while comp<L:
            comp *= 2
            bits += 1
        if comp != L:
            raise ValueError('permute_vector: wrong length of v -- must be 2**n')
        rindices = [bit_reverse(i, bits) for i in range(L)]
        return v[rindices],bits
    
    def dft(v):
        N = v.shape[0]
        a,b = np.meshgrid(
            np.linspace(0,N-1,N,dtype=np.complex128),
            np.linspace(0,N-1,N,dtype=np.complex128),
        )
        M = np.exp((-2j*np.pi*a*b)/N)
    
        return np.dot(M,v)
    
    
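    def fft(v):
        """
        Iterative radix-2 FFT (decimation in time) of a vector whose
        length is a power of two; same sign convention as dft().
        """
        w,bits = permute_vector(v)
        N = w.shape[0]
        z=np.exp(np.array(-2j,dtype=np.complex128)*np.pi/N)  ##primitive N-th root of unity
    
        ##starting fft
        for i in range(bits): 
            dist = 2**i  ##distance between 'exchange pairs'
            group = dist*2 ##size of sub-groups
            for start in range(0,N,group):
                for offset in range(group//2):
                    pos1 = start+offset
                    pos2 = pos1+dist
                    ##twiddle factors exp(-2j*pi*pos/group), written as powers of z
                    alpha1 = z**((pos1*N//group)%N)
                    alpha2 = z**((pos2*N//group)%N)
                    ##butterfly: both new values are computed from the old pair
                    w[pos1],w[pos2] = w[pos1]+alpha1*w[pos2],w[pos1]+alpha2*w[pos2]
        return w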
    
    if __name__ == '__main__':
    
        #test the fft
        for n in [2**i for i in range(1,5)]:
            print('-'*25+'n={}'.format(n)+'-'*25)
            v = np.linspace(0,n-1,n, dtype=np.complex128)
            print('v = ')
            print(v)
            print('fft(v) = ')
            print(fft(v))
            print('dft(v) = ')
            print(dft(v))
            print('relative error:')
            print(abs(fft(v)-dft(v))/abs(dft(v)))
    

    This gives the following output:

    -------------------------n=2-------------------------
    v = 
    [ 0.+0.j  1.+0.j]
    fft(v) = 
    [ 1. +0.00000000e+00j -1. -1.22464680e-16j]
    dft(v) = 
    [ 1. +0.00000000e+00j -1. -1.22464680e-16j]
    relative error:
    [ 0.  0.]
    -------------------------n=4-------------------------
    v = 
    [ 0.+0.j  1.+0.j  2.+0.j  3.+0.j]
    fft(v) = 
    [ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -4.89858720e-16j
     -2. -2.00000000e+00j]
    dft(v) = 
    [ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -7.34788079e-16j
     -2. -2.00000000e+00j]
    relative error:
    [  0.00000000e+00   0.00000000e+00   1.22464680e-16   3.51083347e-16]
    -------------------------n=8-------------------------
    v = 
    [ 0.+0.j  1.+0.j  2.+0.j  3.+0.j  4.+0.j  5.+0.j  6.+0.j  7.+0.j]
    fft(v) = 
    [ 28. +0.00000000e+00j  -4. +9.65685425e+00j  -4. +4.00000000e+00j
      -4. +1.65685425e+00j  -4. -7.10542736e-15j  -4. -1.65685425e+00j
      -4. -4.00000000e+00j  -4. -9.65685425e+00j]
    dft(v) = 
    [ 28. +0.00000000e+00j  -4. +9.65685425e+00j  -4. +4.00000000e+00j
      -4. +1.65685425e+00j  -4. -3.42901104e-15j  -4. -1.65685425e+00j
      -4. -4.00000000e+00j  -4. -9.65685425e+00j]
    relative error:
    [  0.00000000e+00   6.79782332e-16   7.40611132e-16   1.85764404e-15
       9.19104080e-16   3.48892999e-15   3.92837008e-15   1.35490975e-15]
    -------------------------n=16-------------------------
    v = 
    [  0.+0.j   1.+0.j   2.+0.j   3.+0.j   4.+0.j   5.+0.j   6.+0.j   7.+0.j
       8.+0.j   9.+0.j  10.+0.j  11.+0.j  12.+0.j  13.+0.j  14.+0.j  15.+0.j]
    fft(v) = 
    [ 120. +0.00000000e+00j   -8. +4.02187159e+01j   -8. +1.93137085e+01j
       -8. +1.19728461e+01j   -8. +8.00000000e+00j   -8. +5.34542910e+00j
       -8. +3.31370850e+00j   -8. +1.59129894e+00j   -8. +2.84217094e-14j
       -8. -1.59129894e+00j   -8. -3.31370850e+00j   -8. -5.34542910e+00j
       -8. -8.00000000e+00j   -8. -1.19728461e+01j   -8. -1.93137085e+01j
       -8. -4.02187159e+01j]
    dft(v) = 
    [ 120. +0.00000000e+00j   -8. +4.02187159e+01j   -8. +1.93137085e+01j
       -8. +1.19728461e+01j   -8. +8.00000000e+00j   -8. +5.34542910e+00j
       -8. +3.31370850e+00j   -8. +1.59129894e+00j   -8. -6.08810394e-14j
       -8. -1.59129894e+00j   -8. -3.31370850e+00j   -8. -5.34542910e+00j
       -8. -8.00000000e+00j   -8. -1.19728461e+01j   -8. -1.93137085e+01j
       -8. -4.02187159e+01j]
    relative error:
    [  0.00000000e+00   1.09588741e-15   1.45449990e-15   6.36716793e-15
       8.53211992e-15   9.06818284e-15   1.30922044e-14   5.40949529e-15
       1.11628436e-14   1.23698141e-14   1.50430426e-14   3.02428869e-14
       2.84810617e-14   1.16373983e-14   1.10680934e-14   3.92841628e-15]
    

    This was quite a nice challenge -- I learned a lot! You can verify the results of the code online, for instance here.
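The sketch below shows the effect of the dtype fix from the first part of this answer. It ports the question's loops to Python 3 (// in place of the Python 2 integer /) and stores v as complex before the butterflies run. The names question_fft and matrix_dft are hypothetical. The check only covers the 4-element example from the question.

```python
import numpy as np


def matrix_dft(v):
    """Direct method from the question: v times the DFT matrix (+2*pi*i sign)."""
    N = len(v)
    k = np.arange(N)
    M = np.exp(2j * np.pi * np.outer(k, k) / N)
    return np.dot(v, M)


def question_fft(v):
    """Hypothetical Python 3 port of the question's loops, with complex storage."""
    v = np.array(v, dtype=complex)  # the fix: complex results stay complex
    w = v.copy()                    # original order, used to stop the reordering
    N = len(v)
    z = np.exp(2j * np.pi / N)

    def f(x):
        # even-indexed entries first, then the odd-indexed ones
        return np.concatenate([x[::2], x[1::2]])

    # reorder exactly as the question does
    while (w != f(v)).any():
        v = f(v)

    t = [0] * N  # scratch copy of the first half of each block
    a = 2
    while a <= N:
        for k in range(N // a):
            for y in range(a // 2):
                t[y] = v[a * k + y]
            for i in range(a // 2):
                v[a * k + i] += v[a * k + i + a // 2] * z**i
                v[a * k + i + a // 2] = t[i] - v[a * k + i + a // 2] * z**i
        a *= 2
    return v


x = np.array([1, -1, 2, -3])
print(question_fft(x))
print(matrix_dft(x))
print(np.allclose(question_fft(x), matrix_dft(x)))  # → True
```

This only exercises the 4-element input. The EDIT above reports further errors in the original loops, so longer vectors need their own checks.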
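The bit-reversal order built by permute_vector can be compared with the reordering loop in the question, which repeats the even/odd split until one more split would restore the input. The sketch below uses hypothetical helper names and applies both orderings to the indices 0..N-1. The two orders agree for N = 4 but already differ for N = 8.

```python
import numpy as np


def bit_reverse(x, n):
    """Reverse the lowest n bits of x (same idea as the answer's helper)."""
    return sum(1 << (n - 1 - i) for i in range(n) if x >> i & 1)


def bit_reversed_order(N):
    """Indices 0..N-1 in bit-reversed order; N must be a power of two."""
    bits = N.bit_length() - 1
    return [bit_reverse(i, bits) for i in range(N)]


def shuffle_order(N):
    """Order produced by the question's reordering loop, applied to 0..N-1."""
    w = np.arange(N)
    v = w

    def f(x):
        return np.concatenate([x[::2], x[1::2]])

    while (w != f(v)).any():
        v = f(v)
    return v.tolist()


for N in (4, 8):
    print(N, bit_reversed_order(N), shuffle_order(N))
# 4 [0, 2, 1, 3] [0, 2, 1, 3]
# 8 [0, 4, 2, 6, 1, 5, 3, 7] [0, 4, 1, 5, 2, 6, 3, 7]
```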
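A recursive form of the same radix-2 split is a useful independent cross-check. It can be compared directly against numpy.fft.fft, which uses the same e^(-2*pi*i*n*k/N) sign as the dft function above. This is a sketch of a different, recursive formulation, not the answer's code. The name fft_recursive is hypothetical, and the function assumes a power-of-two length.

```python
import numpy as np


def fft_recursive(v):
    """Radix-2 decimation-in-time FFT; len(v) must be a power of two."""
    v = np.asarray(v, dtype=np.complex128)
    N = v.shape[0]
    if N == 0 or N & (N - 1):
        raise ValueError("length must be a power of two")
    if N == 1:
        return v.copy()
    even = fft_recursive(v[::2])  # transform of v[0], v[2], ...
    odd = fft_recursive(v[1::2])  # transform of v[1], v[3], ...
    twiddled = np.exp(-2j * np.pi * np.arange(N // 2) / N) * odd
    # X[k] = E[k] + w^k O[k] and X[k + N/2] = E[k] - w^k O[k]
    return np.concatenate([even + twiddled, even - twiddled])


for n in [2**i for i in range(1, 6)]:
    v = np.linspace(0, n - 1, n, dtype=np.complex128)
    err = np.max(np.abs(fft_recursive(v) - np.fft.fft(v)))
    print(n, err)
```

With the answer's fft in scope, np.allclose(fft(v), np.fft.fft(v)) gives the same kind of check for arbitrary power-of-two inputs.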