
Implementing FFT and IFFT in Python 3


When I multiply two big integers using the FFT, the results of my FFT and IFFT are always wrong.

method

To implement the FFT, I followed this pseudocode:

[Image: the pseudocode of FFT]

The FFT and IFFT equations are below. So, to implement the IFFT, I just swap the roles of a and y, replace $\omega_n$ with $\omega_n^{-1}$, and divide by n. My function uses flag to distinguish the two.

  • For the FFT, y is $y_k = \sum_{j=0}^{n-1} a_j \omega_n^{kj}$, where $\omega_n = e^{2\pi i/n}$

  • For the IFFT, a is $a_j = \frac{1}{n} \sum_{k=0}^{n-1} y_k \omega_n^{-kj}$
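To get a ground truth that does not depend on any library's sign convention, the two sums above can be evaluated directly in O(n²). The sketch below uses a hypothetical helper, `naive_dft` (not part of the original code): `sign=+1` reproduces the $\omega_n = e^{2\pi i/n}$ convention used here, while `sign=-1` is the convention `numpy.fft.fft` uses. The inverse uses $\omega_n^{-1}$ and divides by n exactly once.

```python
import cmath
from typing import List, Sequence


def naive_dft(a: Sequence[complex], inverse: bool = False, sign: int = 1) -> List[complex]:
    """O(n^2) DFT evaluated straight from the defining sums (hypothetical helper)."""
    n = len(a)
    # The inverse flips the exponent, i.e. uses omega_n ** -1.
    s = -sign if inverse else sign
    out: List[complex] = []
    for k in range(n):
        total = sum(a[j] * cmath.exp(s * 2j * cmath.pi * k * j / n) for j in range(n))
        # Divide by n exactly once, and only for the inverse transform.
        out.append(total / n if inverse else total)
    return out


if __name__ == "__main__":
    x = [1, 2, 3, 4, 5, 6, 7, 8]
    y = naive_dft(x)  # same convention as the question's FFT(x, False)
    print(y[1])
    print(naive_dft(x, sign=-1)[1])  # numpy.fft.fft convention
    back = naive_dft(y, inverse=True)
    print([round(v.real, 9) for v in back])
```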

problem

To find the problem, I try to compare the results between numpy.fft and my function.

  1. FFT. My results look the same as numpy's, except that the imaginary parts have the opposite sign. For example (the second element of case2 below):
    • my function result: -4-9.65685424949238j
    • numpy result: -4+9.65685424949238j
  2. IFFT. The results are simply wrong, and I can't find any pattern.
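Observation 1 follows from the sign of the exponent: for real input, replacing $e^{2\pi i/n}$ with $e^{-2\pi i/n}$ conjugates every term of the sum, so each output of one convention is the complex conjugate of the other's. A quick pure-Python check, using hypothetical helper names `dft_with_sign` and `conjugate_relation_holds`:

```python
import cmath
from typing import List, Sequence


def dft_with_sign(a: Sequence[complex], sign: int) -> List[complex]:
    """Naive DFT using omega_n = exp(sign * 2*pi*i/n)."""
    n = len(a)
    return [
        sum(a[j] * cmath.exp(sign * 2j * cmath.pi * k * j / n) for j in range(n))
        for k in range(n)
    ]


def conjugate_relation_holds(a: Sequence[complex], tol: float = 1e-9) -> bool:
    """True if the +sign transform equals the conjugate of the -sign transform."""
    plus = dft_with_sign(a, +1)
    minus = dft_with_sign(a, -1)
    return all(abs(p - m.conjugate()) < tol for p, m in zip(plus, minus))


if __name__ == "__main__":
    print(conjugate_relation_holds([1, 2, 3, 4, 5, 6, 7, 8]))  # True (real input)
    print(conjugate_relation_holds([1j, 2]))  # False (complex input)
```

The relation only holds for real input, which is why flipping the sign of `omg_n` is the real fix rather than conjugating the output.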

python code

Here is my FFT function and the comparison code:

from typing import List
from cmath import pi, exp
from numpy.fft import fft, ifft


def FFT(a: List, flag: bool) -> List:
    """realize DFT using FFT"""
    n = len(a)
    if n == 1:
        return a

    # complex root
    omg_n = exp(2 * pi * 1j / n)
    if flag:
        # IFFT
        omg_n = 1 / omg_n
    omg = 1

    # split a into 2 parts
    a0 = a[::2]  # even
    a1 = a[1::2]  # odd

    # corresponding y
    y0 = FFT(a0, flag)
    y1 = FFT(a1, flag)

    # result y
    y = [0] * n
    for k in range(n // 2):
        y[k] = y0[k] + omg * y1[k]
        y[k + n // 2] = y0[k] - omg * y1[k]
        omg = omg * omg_n

    # IFFT
    if flag:
        y = [i / n for i in y]
    return y

if __name__ == '__main__':
    test_cases = [
        [1, 1],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0, ],
    ]

    print("test FFT")
    for i, case in enumerate(test_cases):
        print(f"case{i + 1}", case)
        manual_result = FFT(case, False)
        numpy_result = fft(case).tolist()
        print("manual_result:", manual_result)
        print("numpy_result:", numpy_result)
        print("difference:", [i - j for i, j in zip(manual_result, numpy_result)])
        print()



    print("test IFFT")
    for i, case in enumerate(test_cases):
        print(f"case{i + 1}", case)
        manual_result = FFT(case, True)
        numpy_result = ifft(case).tolist()
        print("manual_result:", manual_result)
        print("numpy_result:", numpy_result)
        print("difference:", [i - j for i, j in zip(manual_result, numpy_result)])
        print()

The FFT output:

test FFT
case1 [1, 1]
manual_result: [2, 0]
numpy_result: [(2+0j), 0j]
difference: [0j, 0j]

case2 [1, 2, 3, 4, 5, 6, 7, 8]
manual_result: [36, (-4-9.65685424949238j), (-4-4.000000000000001j), (-4-1.6568542494923815j), -4, (-4+1.6568542494923806j), (-4+4.000000000000001j), (-3.999999999999999+9.656854249492381j)]
numpy_result: [(36+0j), (-4+9.65685424949238j), (-4+4j), (-4+1.6568542494923806j), (-4+0j), (-4-1.6568542494923806j), (-4-4j), (-4-9.65685424949238j)]
difference: [0j, -19.31370849898476j, -8j, -3.313708498984762j, 0j, 3.313708498984761j, 8j, (8.881784197001252e-16+19.31370849898476j)]

case3 [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0]
manual_result: [41, (-12.710780677203363+13.231540329804117j), (12.82842712474619+7.2426406871192865j), (-14.692799048494296+7.4256307475248935j), (1.0000000000000013-12j), (5.763866860359768+6.0114171851517995j), (7.171572875253808+1.2426406871192839j), (-10.360287134662114+11.817326767431025j), -3, (-10.360287134662112-11.817326767431021j), (7.17157287525381-1.2426406871192848j), (5.763866860359771-6.011417185151798j), (0.9999999999999987+12j), (-14.692799048494292-7.425630747524895j), (12.828427124746192-7.242640687119286j), (-12.710780677203362-13.23154032980412j)]
numpy_result: [(41+0j), (-12.710780677203363-13.231540329804115j), (12.82842712474619-7.242640687119286j), (-14.692799048494292-7.4256307475248935j), (1+12j), (5.763866860359768-6.011417185151798j), (7.17157287525381-1.2426406871192857j), (-10.360287134662112-11.81732676743102j), (-3+0j), (-10.360287134662112+11.81732676743102j), (7.17157287525381+1.2426406871192857j), (5.763866860359768+6.011417185151798j), (1-12j), (-14.692799048494292+7.4256307475248935j), (12.82842712474619+7.242640687119286j), (-12.710780677203363+13.231540329804115j)]
difference: [0j, 26.46308065960823j, 14.485281374238571j, (-3.552713678800501e-15+14.851261495049787j), (1.3322676295501878e-15-24j), 12.022834370303597j, (-1.7763568394002505e-15+2.4852813742385695j), (-1.7763568394002505e-15+23.634653534862046j), 0j, -23.63465353486204j, -2.4852813742385704j, (3.552713678800501e-15-12.022834370303595j), (-1.3322676295501878e-15+24j), -14.851261495049789j, (1.7763568394002505e-15-14.485281374238571j), (1.7763568394002505e-15-26.463080659608238j)]

The IFFT result:

test IFFT
case1 [1, 1]
manual_result: [1.0, 0.0]
numpy_result: [(1+0j), 0j]
difference: [0j, 0j]

case2 [1, 2, 3, 4, 5, 6, 7, 8]
manual_result: [0.5625, (-0.0625+0.15088834764831843j), (-0.0625+0.062499999999999986j), (-0.0625+0.025888347648318405j), -0.0625, (-0.0625-0.025888347648318433j), (-0.0625-0.062499999999999986j), (-0.062499999999999986-0.1508883476483184j)]
numpy_result: [(4.5+0j), (-0.5-1.2071067811865475j), (-0.5-0.5j), (-0.5-0.20710678118654757j), (-0.5+0j), (-0.5+0.20710678118654757j), (-0.5+0.5j), (-0.5+1.2071067811865475j)]
difference: [(-3.9375+0j), (0.4375+1.357995128834866j), (0.4375+0.5625j), (0.4375+0.23299512883486598j), (0.4375+0j), (0.4375-0.232995128834866j), (0.4375-0.5625j), (0.4375-1.357995128834866j)]

case3 [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0]
manual_result: [0.0400390625, (-0.01241287175508141-0.012921426103324331j), (0.012527760864009951-0.007072891296014926j), (-0.014348436570795205-0.007251592526879778j), (0.0009765625000000013+0.01171875j), (0.005628776230820083-0.005870524594874804j), (0.007003489135990047-0.0012135162960149274j), (-0.01011746790494347-0.011540358171319353j), -0.0029296875, (-0.010117467904943469+0.011540358171319355j), (0.007003489135990049+0.0012135162960149274j), (0.005628776230820081+0.005870524594874803j), (0.0009765624999999987-0.01171875j), (-0.014348436570795205+0.0072515925268797805j), (0.012527760864009953+0.007072891296014926j), (-0.012412871755081408+0.01292142610332433j)]
numpy_result: [(2.5625+0j), (-0.7944237923252102+0.8269712706127572j), (0.8017766952966369+0.45266504294495535j), (-0.9182999405308933+0.46410192172030584j), (0.0625-0.75j), (0.3602416787724855+0.37571357407198736j), (0.44822330470336313+0.07766504294495535j), (-0.647517945916382+0.7385829229644387j), (-0.1875+0j), (-0.647517945916382-0.7385829229644387j), (0.44822330470336313-0.07766504294495535j), (0.3602416787724855-0.37571357407198736j), (0.0625+0.75j), (-0.9182999405308933-0.46410192172030584j), (0.8017766952966369-0.45266504294495535j), (-0.7944237923252102-0.8269712706127572j)]
difference: [(-2.5224609375+0j), (0.7820109205701288-0.8398926967160816j), (-0.7892489344326269-0.45973793424097026j), (0.903951503960098-0.47135351424718563j), (-0.0615234375+0.76171875j), (-0.3546129025416654-0.38158409866686216j), (-0.4412198155673731-0.07887855924097029j), (0.6374004780114385-0.7501232811357581j), (0.1845703125+0j), (0.6374004780114385+0.7501232811357581j), (-0.4412198155673731+0.07887855924097029j), (-0.3546129025416654+0.38158409866686216j), (-0.0615234375-0.76171875j), (0.903951503960098+0.47135351424718563j), (-0.7892489344326269+0.45973793424097026j), (0.7820109205701288+0.8398926967160816j)]

@pjs, thank you for reminding me that this radix-2 FFT requires len(data) to be a power of 2.
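A minimal sketch of zero-padding the input up to the next power of two, so a radix-2 FFT like the one above can be applied to data of any length. `next_power_of_two` and `zero_pad` are hypothetical helpers. When the FFT is used for multiplication, the padded length should also be at least len(a) + len(b) - 1; otherwise the product wraps around (circular convolution).

```python
from typing import List, Sequence


def next_power_of_two(m: int) -> int:
    """Smallest power of two >= m (returns 1 for m <= 1)."""
    return 1 if m <= 1 else 1 << (m - 1).bit_length()


def zero_pad(a: Sequence[complex], length: int) -> List[complex]:
    """Append zeros to a so that it has exactly `length` elements."""
    if length < len(a):
        raise ValueError("length must be at least len(a)")
    return list(a) + [0] * (length - len(a))


if __name__ == "__main__":
    data = [1, 4, 2, 9, 0, 3]
    padded = zero_pad(data, next_power_of_two(len(data)))
    print(padded)  # [1, 4, 2, 9, 0, 3, 0, 0]
```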


Solution

  • As was pointed out in the comments, you used a positive sign in the computation of omg_n. There are different definitions of the DFT, so this isn't wrong by itself. However, it naturally leads to differences when you compare your results with an implementation that uses a negative sign, as numpy.fft.fft does. Adjusting your implementation to also use a negative sign covers all the forward transform cases (leaving only small roundoff errors on the order of ~10^-16).

    For the inverse transform cases, your implementation ends up scaling the result by 1/n at every stage of the recursion (each level dividing by its own length), instead of only once at the end. To correct this, remove the scaling from the recursion and normalize only once, at the top level:

    from typing import List
    from cmath import pi, exp
    
    
    def FFTrecursion(a: List, flag: bool) -> List:
        """Recursion of the FFT implementation"""
    
        n = len(a)
        if n == 1:
            return a
    
        # complex root
        omg_n = exp(-2 * pi * 1j / n)
        if flag:
            # IFFT
            omg_n = 1 / omg_n
        omg = 1
    
        # split a into 2 parts
        a0 = a[::2]  # even
        a1 = a[1::2]  # odd
    
        # corresponding y
        y0 = FFTrecursion(a0, flag)
        y1 = FFTrecursion(a1, flag)
    
        # result y
        y = [0] * n
        for k in range(n // 2):
            y[k] = y0[k] + omg * y1[k]
            y[k + n // 2] = y0[k] - omg * y1[k]
            omg = omg * omg_n
    
        return y
    
    
    def FFT(a: List, flag: bool) -> List:
        """realize DFT using FFT"""
    
        y = FFTrecursion(a, flag)
    
        # IFFT final scaling
        if flag:
            n = len(a)
            y = [i / n for i in y]
        return y
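The size of the original inverse error can be checked against the output in the question. Dividing at every recursion level scales the result by 1/(n · n/2 · ... · 2) instead of 1/n. For n = 8 that is 1/64 instead of 1/8 (the question's 0.5625 vs numpy's 4.5, a factor of 8), and for n = 16 it is 1/1024 instead of 1/16 (0.0400390625 vs 2.5625, a factor of 64). The remaining sign difference comes from the conjugated exponent. A hypothetical helper that computes the accumulated divisor:

```python
def per_level_scale(n: int) -> int:
    """Total divisor when every recursion level divides by its own length.

    The original FFT(a, True) divides by n at the top, by n/2 one level down,
    and so on down to 2, so outputs end up scaled by 1/(n * n/2 * ... * 2).
    """
    total = 1
    m = n
    while m > 1:
        total *= m
        m //= 2
    return total


if __name__ == "__main__":
    for n in (2, 8, 16):
        # extra factor compared with the correct single division by n
        print(n, per_level_scale(n), per_level_scale(n) // n)
    # prints:
    # 2 2 1
    # 8 64 8
    # 16 1024 64
```

For n = 2 the extra factor is 1, which is why case1 matched numpy.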
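Since the original goal was multiplying big integers, here is a minimal sketch of how the corrected transform could be used for that, assuming non-negative inputs split into base-10 digits. `_fft` mirrors `FFTrecursion` with the sign fix (forward uses $e^{-2\pi i/n}$, inverse uses $e^{+2\pi i/n}$, no scaling inside the recursion), and `multiply` is a hypothetical name. Because it relies on floating-point rounding, it is only exact while the error on each coefficient stays well below 0.5; for very large inputs, a number-theoretic transform is a common alternative.

```python
from cmath import exp, pi
from typing import List


def _fft(a: List[complex], invert: bool) -> List[complex]:
    """Unnormalized radix-2 FFT; len(a) must be a power of two."""
    n = len(a)
    if n == 1:
        return list(a)
    # Negative sign for the forward transform (numpy convention), positive for inverse.
    omg_n = exp((2j if invert else -2j) * pi / n)
    y0 = _fft(a[::2], invert)   # even-indexed elements
    y1 = _fft(a[1::2], invert)  # odd-indexed elements
    y = [0j] * n
    omg = 1 + 0j
    for k in range(n // 2):
        t = omg * y1[k]
        y[k] = y0[k] + t
        y[k + n // 2] = y0[k] - t
        omg *= omg_n
    return y


def multiply(x: int, y: int) -> int:
    """Multiply non-negative integers via FFT convolution of their decimal digits."""
    if x == 0 or y == 0:
        return 0
    a = [int(d) for d in reversed(str(x))]  # least significant digit first
    b = [int(d) for d in reversed(str(y))]
    size = 1
    while size < len(a) + len(b) - 1:  # avoid circular wrap-around
        size *= 2
    fa = _fft(a + [0] * (size - len(a)), False)
    fb = _fft(b + [0] * (size - len(b)), False)
    prod = _fft([u * v for u, v in zip(fa, fb)], True)
    # Normalize exactly once by 1/size, then round to the nearest integer.
    coeffs = [round(c.real / size) for c in prod]
    # Propagate carries to turn convolution coefficients into decimal digits.
    digits = []
    carry = 0
    for c in coeffs:
        carry += c
        digits.append(carry % 10)
        carry //= 10
    while carry:
        digits.append(carry % 10)
        carry //= 10
    return int("".join(map(str, reversed(digits))))


if __name__ == "__main__":
    print(multiply(999999999, 999999999))  # 999999998000000001
```

The single `c.real / size` division is the one-time 1/n normalization described above.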