
How can I run (or emulate) a PyTorch model that uses the ATen STFT implementation on an ARM-based CPU?


I am trying to run my PyTorch ASR model on an ARM-based device without a GPU. As far as I know, ARM does not support Intel MKL, which ATen uses for its CPU FFT. Naturally, I get the following error when I try to run inference:

RuntimeError: fft: ATen not compiled with MKL support

How can I solve this problem? Are there any alternatives that I can use?
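To confirm that the installed PyTorch build really lacks MKL, a minimal check is shown below. It assumes only the public torch.backends.mkl.is_available() and torch.__config__.show() helpers; it does not fix anything, it just tells you which backends the build was compiled with.

```python
import torch

# True only if this PyTorch build was compiled with Intel MKL support
has_mkl = torch.backends.mkl.is_available()
print("MKL available:", has_mkl)

# Full build configuration string (BLAS/LAPACK backends, compiler flags, etc.)
print(torch.__config__.show())
```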


Solution

  • I solved this issue by bypassing PyTorch's stft implementation. This may not be feasible for everyone, but in my case it let me run predictions with my model on the ARM device without any issues.

    The problem stemmed from the _VF.stft call inside torch.stft in packages/torch/functional.py.

    I changed the line

    return _VF.stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex) 
    

    with:

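    # Requires `import librosa` and `import numpy as np` at the top of functional.py.
    # torch.stft has already padded `input` above when center=True, so librosa must not pad again.
    win = window.cpu().numpy() if window is not None else "boxcar"  # torch treats window=None as rectangular
    librosa_stft = librosa.stft(input.cpu().detach().numpy().reshape(-1), n_fft=n_fft,
                                hop_length=hop_length or n_fft // 4,  # torch's default hop is n_fft // 4
                                win_length=win_length, window=win, center=False)
    # Complex (freq, frames) -> real (1, freq, frames, 2), the layout _VF.stft returns for return_complex=False
    librosa_stft = np.stack([librosa_stft.real, librosa_stft.imag], axis=-1)
    return torch.from_numpy(np.expand_dims(librosa_stft, 0))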
    

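    This code could be optimized further; I just tried to replicate what PyTorch does using librosa. In my case the output was the same in both versions, but check your own outputs if you decide to use this method. Note that the replacement ignores the normalized, onesided and return_complex arguments (librosa always returns the one-sided spectrum) and flattens the input, so it assumes a single signal (batch size 1).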
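If adding librosa as a dependency is not an option, or you want an independent reference to check outputs against, the same bypass can be sketched with NumPy alone: stock NumPy builds compute numpy.fft with their own bundled FFT code rather than MKL. The function name stft_numpy is hypothetical, and it assumes a 1-D input and mirrors the documented torch.stft defaults (n_fft // 4 hop, window=None as all ones, a shorter window zero-padded to n_fft, n_fft // 2 samples of padding when center=True, one-sided output). It does not implement `normalized`.

```python
import numpy as np


def stft_numpy(x, n_fft, hop_length=None, win_length=None, window=None,
               center=True, pad_mode="reflect"):
    """Hypothetical NumPy-only STFT for a 1-D signal.

    Returns a real array of shape (n_fft // 2 + 1, n_frames, 2), the layout of
    torch.stft(..., return_complex=False) for 1-D input. `normalized` is not
    handled. pad_mode uses NumPy's names: "reflect" and "constant" behave like
    torch's modes of the same name, torch's "replicate" is NumPy's "edge".
    """
    x = np.asarray(x, dtype=np.float64)
    if hop_length is None:
        hop_length = n_fft // 4
    if win_length is None:
        win_length = n_fft
    # torch.stft treats window=None as a rectangular window of ones
    if window is None:
        window = np.ones(win_length)
    window = np.asarray(window, dtype=np.float64)
    # A window shorter than n_fft is zero-padded on both sides to length n_fft
    left = (n_fft - win_length) // 2
    window = np.pad(window, (left, n_fft - win_length - left))
    if center:
        # Pad n_fft // 2 samples on each side, as torch.stft does before calling _VF.stft
        x = np.pad(x, n_fft // 2, mode=pad_mode)
    n_frames = 1 + (len(x) - n_fft) // hop_length
    frames = np.stack([x[t * hop_length:t * hop_length + n_fft] for t in range(n_frames)])
    # One-sided spectrum of each windowed frame, transposed to (freq, frames)
    spec = np.fft.rfft(frames * window, axis=-1).T
    return np.stack([spec.real, spec.imag], axis=-1)


# Usage: a cosine that falls exactly on FFT bin 2, rectangular window, no centering
n_fft = 8
n = np.arange(32)
signal = np.cos(2 * np.pi * 2 * n / n_fft)
out = stft_numpy(signal, n_fft=n_fft, hop_length=4, center=False)
print(out.shape)  # (5, 7, 2)
magnitude = np.hypot(out[..., 0], out[..., 1])  # (freq, frames)
```

Before relying on it on the ARM device, compare it on a machine where torch.stft works, for example with np.allclose(stft_numpy(x.numpy(), n_fft, window=w.numpy()), torch.view_as_real(torch.stft(x, n_fft, window=w, return_complex=True)).numpy(), atol=1e-4) for a 1-D tensor x and window tensor w.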