Implementing FFT with Pytorch


I am trying to implement FFT by using the conv1d function provided in Pytorch.

Generating an artificial signal

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.functional import conv1d

from scipy.fft import fft  # in current SciPy, scipy.fft is a module, so import the function from it

import matplotlib.pyplot as plt

%matplotlib inline

# Creating filters

d = 4096 # size of windows

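def create_filters(d):
    # Build d sine and d cosine filters, each d samples long (shape: d x 1 x d).
    x = np.arange(0, d, 1)
    wsin = np.empty((d,1,d), dtype=np.float32)
    wcos = np.empty((d,1,d), dtype=np.float32)
    window_mask = 1.0-1.0*np.cos(x)  # computed but never applied to the filters
    for ind in range(d):
        # Row ind holds frequency (ind+1)/d, i.e. DFT bin ind+1 rather than bin ind.
        wsin[ind,0,:] = np.sin(2*np.pi*((ind+1)/d)*x)
        wcos[ind,0,:] = np.cos(2*np.pi*((ind+1)/d)*x)

    return wsin,wcos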

wsin, wcos = create_filters(d)
wsin_var = Variable(torch.from_numpy(wsin), requires_grad=False)
wcos_var = Variable(torch.from_numpy(wcos),requires_grad=False)

# Creating signal

t = np.linspace(0,1,4096)
x = np.sin(2*np.pi*100*t)+np.sin(2*np.pi*200*t)+np.random.normal(scale=5,size=(4096))

plt.plot(x) 

[Figure: plot of the noisy signal x]

FFT with Pytorch

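# conv1d expects input of shape (batch, channels, length)
signal_input = torch.from_numpy(x.reshape(1,-1))[:,None,:4096]

signal_input = signal_input.float()

# Squared sine response plus squared cosine response for each filter
zx = conv1d(signal_input, wsin_var, stride=1).pow(2)+conv1d(signal_input, wcos_var, stride=1).pow(2)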

[Figure: spectrum computed with conv1d]

FFT with Scipy

fig = plt.figure(figsize=(20,5))
plt.plot(np.abs(fft(x).reshape(-1))[:500])

My Question

As you can see, the two outputs have peaks in similar places, so my implementation is not totally wrong. However, they still differ in some respects, such as the scale of the spectrum and the signal-to-noise ratio. I am unable to figure out what is missing to get exactly the same result.



Solution

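  • You calculated the power (the squared magnitude) rather than the amplitude. Add the line zx = zx.pow(0.5) after the conv1d call to take the square root and get the amplitude, which is what np.abs(fft(x)) returns. Squaring also widens the gap between the peaks and the noise floor, which accounts for the apparent difference in signal-to-noise ratio.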
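
To check the square-root fix numerically, here is a minimal sketch rather than a drop-in replacement for the code in the question. It makes a few assumptions of its own: a smaller window (d = 256) to keep the filter bank small, a seeded noise term with scale 1, tones at 10 and 20 cycles, and np.fft.fft as the reference. It also suggests that the filters in the question start at bin 1 (because of the ind+1 term), so the output has to be rotated by one position before it lines up with the FFT.

```python
import numpy as np
import torch
from torch.nn.functional import conv1d

d = 256  # assumption: smaller than the question's 4096 to keep the filter bank small


def create_filters(d):
    """Build sine/cosine filter banks; row ind holds DFT bin ind + 1, as in the question."""
    n = np.arange(d)
    k = np.arange(1, d + 1).reshape(-1, 1)  # same (ind + 1) indexing as the question
    phase = 2 * np.pi * k * n / d
    wsin = np.sin(phase).astype(np.float32)[:, None, :]  # shape (d, 1, d)
    wcos = np.cos(phase).astype(np.float32)[:, None, :]
    return torch.from_numpy(wsin), torch.from_numpy(wcos)


# Test signal: two tones plus seeded Gaussian noise (illustrative values).
rng = np.random.default_rng(0)
t = np.arange(d) / d
x = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 20 * t) + rng.normal(scale=1.0, size=d)

signal = torch.from_numpy(x).float()[None, None, :]  # (batch=1, channels=1, length=d)
wsin, wcos = create_filters(d)

# Power: squared sine response plus squared cosine response, shape (1, d, 1).
power = conv1d(signal, wsin).pow(2) + conv1d(signal, wcos).pow(2)

# The fix from the answer: the square root turns power into amplitude.
amplitude = power.sqrt().reshape(-1).numpy()

# Row ind is bin ind + 1 and the last row is bin d (equivalent to bin 0),
# so rotating by one position puts the DC term first, as np.fft.fft does.
aligned = np.roll(amplitude, 1)

reference = np.abs(np.fft.fft(x))
print("max abs difference:", np.max(np.abs(aligned - reference)))
```

The remaining difference should be at the level of float32 rounding. Without the square root, the peak values are the squares of the FFT magnitudes, which is the scale mismatch described in the question.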