When I multiply two big integers using the FFT, the results of my FFT and IFFT are always wrong.
To implement the FFT, I followed this pseudocode (image: the pseudocode of FFT).
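The equations of the FFT and IFFT are as follows: $y_k = \sum_{j=0}^{n-1} a_j\,\omega_n^{kj}$ and $a_j = \frac{1}{n}\sum_{k=0}^{n-1} y_k\,\omega_n^{-kj}$. So, when implementing the IFFT, I just replace $a$ with $y$, replace $\omega$ with $\omega^{-1}$, and divide the result by $n$. I use `flag` to distinguish the two cases in my function.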
To find the problem, I compared the results of numpy.fft with those of my function.
-4-9.65685424949238j
-4+9.65685424949238j
Here is my function FFT, and comparison:
from typing import List
from cmath import pi, exp
from numpy.fft import fft, ifft
def FFT(a: List, flag: bool) -> List:
    """Recursive radix-2 transform of ``a``; ``flag`` selects the inverse branch.

    Args:
        a: Input samples. The even/odd split pairs elements up correctly
            only when ``len(a)`` is a power of two.
        flag: ``False`` for the forward transform, ``True`` for the inverse.

    Returns:
        A list of ``len(a)`` transformed values (``a`` itself when it has a
        single element).

    NOTE: this is the version as posted, with its numeric behaviour left
    untouched. It uses ``exp(+2*pi*1j/n)`` as the forward root and, when
    ``flag`` is true, divides by ``n`` at every recursion level rather than
    once, so its results differ from ``numpy.fft.fft`` / ``numpy.fft.ifft``.
    """
    n = len(a)
    if n == 1:
        return a
    # complex root
    omg_n = exp(2 * pi * 1j / n)
    if flag:
        # IFFT
        omg_n = 1 / omg_n
    omg = 1
    # split a into 2 part
    a0 = a[::2]  # even
    a1 = a[1::2]  # odd
    # corresponding y
    y0 = FFT(a0, flag)
    y1 = FFT(a1, flag)
    # result y
    y = [0] * n
    for k in range(n // 2):
        y[k] = y0[k] + omg * y1[k]
        y[k + n // 2] = y0[k] - omg * y1[k]
        omg = omg * omg_n
    # IFFT: this division runs in every recursive call, not only the outermost
    if flag:
        y = [i / n for i in y]
    return y
if __name__ == '__main__':
    # Power-of-two length inputs fed to both the manual and the numpy transform.
    test_cases = [
        [1, 1],
        [1, 2, 3, 4, 5, 6, 7, 8],
        [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0],
    ]
    # (banner, flag passed to FFT, numpy reference function)
    modes = (
        ("test FFT", False, fft),
        ("test IFFT", True, ifft),
    )
    for banner, inverse, reference in modes:
        print(banner)
        for index, case in enumerate(test_cases):
            print(f"case{index + 1}", case)
            manual_result = FFT(case, inverse)
            numpy_result = reference(case).tolist()
            # Element-wise gap between the hand-written and numpy results.
            delta = [mine - theirs for mine, theirs in zip(manual_result, numpy_result)]
            print("manual_result:", manual_result)
            print("numpy_result:", numpy_result)
            print("difference:", delta)
            print()
The FFT output:
test FFT
case1 [1, 1]
manual_result: [2, 0]
numpy_result: [(2+0j), 0j]
difference: [0j, 0j]
case2 [1, 2, 3, 4, 5, 6, 7, 8]
manual_result: [36, (-4-9.65685424949238j), (-4-4.000000000000001j), (-4-1.6568542494923815j), -4, (-4+1.6568542494923806j), (-4+4.000000000000001j), (-3.999999999999999+9.656854249492381j)]
numpy_result: [(36+0j), (-4+9.65685424949238j), (-4+4j), (-4+1.6568542494923806j), (-4+0j), (-4-1.6568542494923806j), (-4-4j), (-4-9.65685424949238j)]
difference: [0j, -19.31370849898476j, -8j, -3.313708498984762j, 0j, 3.313708498984761j, 8j, (8.881784197001252e-16+19.31370849898476j)]
case3 [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0]
manual_result: [41, (-12.710780677203363+13.231540329804117j), (12.82842712474619+7.2426406871192865j), (-14.692799048494296+7.4256307475248935j), (1.0000000000000013-12j), (5.763866860359768+6.0114171851517995j), (7.171572875253808+1.2426406871192839j), (-10.360287134662114+11.817326767431025j), -3, (-10.360287134662112-11.817326767431021j), (7.17157287525381-1.2426406871192848j), (5.763866860359771-6.011417185151798j), (0.9999999999999987+12j), (-14.692799048494292-7.425630747524895j), (12.828427124746192-7.242640687119286j), (-12.710780677203362-13.23154032980412j)]
numpy_result: [(41+0j), (-12.710780677203363-13.231540329804115j), (12.82842712474619-7.242640687119286j), (-14.692799048494292-7.4256307475248935j), (1+12j), (5.763866860359768-6.011417185151798j), (7.17157287525381-1.2426406871192857j), (-10.360287134662112-11.81732676743102j), (-3+0j), (-10.360287134662112+11.81732676743102j), (7.17157287525381+1.2426406871192857j), (5.763866860359768+6.011417185151798j), (1-12j), (-14.692799048494292+7.4256307475248935j), (12.82842712474619+7.242640687119286j), (-12.710780677203363+13.231540329804115j)]
difference: [0j, 26.46308065960823j, 14.485281374238571j, (-3.552713678800501e-15+14.851261495049787j), (1.3322676295501878e-15-24j), 12.022834370303597j, (-1.7763568394002505e-15+2.4852813742385695j), (-1.7763568394002505e-15+23.634653534862046j), 0j, -23.63465353486204j, -2.4852813742385704j, (3.552713678800501e-15-12.022834370303595j), (-1.3322676295501878e-15+24j), -14.851261495049789j, (1.7763568394002505e-15-14.485281374238571j), (1.7763568394002505e-15-26.463080659608238j)]
The IFFT result:
test IFFT
case1 [1, 1]
manual_result: [1.0, 0.0]
numpy_result: [(1+0j), 0j]
difference: [0j, 0j]
case2 [1, 2, 3, 4, 5, 6, 7, 8]
manual_result: [0.5625, (-0.0625+0.15088834764831843j), (-0.0625+0.062499999999999986j), (-0.0625+0.025888347648318405j), -0.0625, (-0.0625-0.025888347648318433j), (-0.0625-0.062499999999999986j), (-0.062499999999999986-0.1508883476483184j)]
numpy_result: [(4.5+0j), (-0.5-1.2071067811865475j), (-0.5-0.5j), (-0.5-0.20710678118654757j), (-0.5+0j), (-0.5+0.20710678118654757j), (-0.5+0.5j), (-0.5+1.2071067811865475j)]
difference: [(-3.9375+0j), (0.4375+1.357995128834866j), (0.4375+0.5625j), (0.4375+0.23299512883486598j), (0.4375+0j), (0.4375-0.232995128834866j), (0.4375-0.5625j), (0.4375-1.357995128834866j)]
case3 [1, 4, 2, 9, 0, 0, 3, 8, 9, 1, 4, 0, 0, 0, 0, 0]
manual_result: [0.0400390625, (-0.01241287175508141-0.012921426103324331j), (0.012527760864009951-0.007072891296014926j), (-0.014348436570795205-0.007251592526879778j), (0.0009765625000000013+0.01171875j), (0.005628776230820083-0.005870524594874804j), (0.007003489135990047-0.0012135162960149274j), (-0.01011746790494347-0.011540358171319353j), -0.0029296875, (-0.010117467904943469+0.011540358171319355j), (0.007003489135990049+0.0012135162960149274j), (0.005628776230820081+0.005870524594874803j), (0.0009765624999999987-0.01171875j), (-0.014348436570795205+0.0072515925268797805j), (0.012527760864009953+0.007072891296014926j), (-0.012412871755081408+0.01292142610332433j)]
numpy_result: [(2.5625+0j), (-0.7944237923252102+0.8269712706127572j), (0.8017766952966369+0.45266504294495535j), (-0.9182999405308933+0.46410192172030584j), (0.0625-0.75j), (0.3602416787724855+0.37571357407198736j), (0.44822330470336313+0.07766504294495535j), (-0.647517945916382+0.7385829229644387j), (-0.1875+0j), (-0.647517945916382-0.7385829229644387j), (0.44822330470336313-0.07766504294495535j), (0.3602416787724855-0.37571357407198736j), (0.0625+0.75j), (-0.9182999405308933-0.46410192172030584j), (0.8017766952966369-0.45266504294495535j), (-0.7944237923252102-0.8269712706127572j)]
difference: [(-2.5224609375+0j), (0.7820109205701288-0.8398926967160816j), (-0.7892489344326269-0.45973793424097026j), (0.903951503960098-0.47135351424718563j), (-0.0615234375+0.76171875j), (-0.3546129025416654-0.38158409866686216j), (-0.4412198155673731-0.07887855924097029j), (0.6374004780114385-0.7501232811357581j), (0.1845703125+0j), (0.6374004780114385+0.7501232811357581j), (-0.4412198155673731+0.07887855924097029j), (-0.3546129025416654+0.38158409866686216j), (-0.0615234375-0.76171875j), (0.903951503960098+0.47135351424718563j), (-0.7892489344326269+0.45973793424097026j), (0.7820109205701288+0.8398926967160816j)]
@pjs, Thank you for your reminder that FFT requires len(data) to be a power of 2.
As was pointed out in the comments, you used a positive sign in the computation of `omg_n`. There are different definitions of the DFT, so this isn't wrong by itself. However, it naturally leads to differences if you compare your results with an implementation that uses a negative sign, as is the case with `numpy.fft.fft`. Adjusting your implementation to also use a negative sign would cover all forward transform cases (leaving only small roundoff errors on the order of $10^{-16}$).
For the inverse transform cases, your implementation ends up scaling the result by `1/n` at every stage of the recursion, instead of only at the final stage. To correct this, simply remove the scaling from the recursion, and normalize only at the final stage:
def FFTrecursion(a: List, flag: bool) -> List:
    """Recursive radix-2 transform of ``a`` without any 1/n normalization.

    Args:
        a: Input samples; ``len(a)`` must be a power of two.
        flag: ``False`` for the forward transform (root ``exp(-2*pi*1j/n)``),
            ``True`` for the inverse (root ``exp(+2*pi*1j/n)``).

    Returns:
        A new list of ``len(a)`` transformed values. For the inverse the
        output is left unscaled; the caller divides by ``len(a)`` once.

    Raises:
        ValueError: If ``len(a)`` is not a positive power of two.
    """
    n = len(a)
    # The even/odd split only pairs up correctly for n == 2**k. Any other
    # length would leave slots of y unfilled, and n == 0 would divide by zero.
    if n == 0 or n & (n - 1):
        raise ValueError(f"FFT length must be a power of two, got {n}")
    if n == 1:
        # Copy so the result never aliases the caller's input.
        return list(a)
    # Principal n-th root of unity; the sign selects forward vs. inverse.
    sign = 1 if flag else -1
    omg_n = exp(sign * 2 * pi * 1j / n)
    omg = 1
    # Transform the even- and odd-indexed halves independently.
    y0 = FFTrecursion(a[::2], flag)
    y1 = FFTrecursion(a[1::2], flag)
    # Butterfly: combine the half-size results into the full-size one.
    half = n // 2
    y = [0] * n
    for k in range(half):
        twiddled = omg * y1[k]
        y[k] = y0[k] + twiddled
        y[k + half] = y0[k] - twiddled
        omg *= omg_n
    return y
def FFT(a: List, flag: bool) -> List:
    """Compute the DFT of ``a``, or its inverse when ``flag`` is true.

    The recursive work is delegated to ``FFTrecursion``; for the inverse the
    1/n normalization is applied here, once, to the finished result.
    """
    spectrum = FFTrecursion(a, flag)
    if not flag:
        return spectrum
    # IFFT final scaling
    size = len(a)
    return [value / size for value in spectrum]