Search code examples
pythonscipysignal-processingdigital-filter

How to implement multi-band-pass filter with scipy.signal.butter


Based on the band-pass filter here, I am trying to make a multi-band filter using the code below. However, the filtered signal is close to zero, which distorts the result when its spectrum is plotted. Should the filter coefficients of each band be normalized? Can someone please suggest how to fix the filter?

from scipy.signal import butter, sosfreqz, sosfilt
from scipy.signal import spectrogram
import matplotlib
import matplotlib.pyplot as plt
from scipy.fft import fft
import numpy as np


def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter as second-order sections.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper edge of the pass-band, in Hz.
    fs : float
        Sample rate in Hz.
    order : int, optional
        Order passed to ``scipy.signal.butter``. A band-pass design doubles
        it, giving ``order`` second-order sections.

    Returns
    -------
    numpy.ndarray
        SOS array of shape ``(order, 6)`` suitable for ``sosfilt``.
    """
    nyquist = fs * 0.5
    # butter() wants digital edges normalized so that 1.0 is the Nyquist frequency.
    edges = [cutoff / nyquist for cutoff in (lowcut, highcut)]
    return butter(order, edges, analog=False, btype='band', output='sos')


def multiband_filter(data, bands, fs, order=10):
    """Keep only the given frequency bands of *data*.

    Each band gets its own Butterworth band-pass filter (from
    ``butter_bandpass``). The filters are applied in parallel to *data* and
    their outputs are summed.

    Cascading the sections instead (``np.vstack`` of all SOS arrays followed
    by a single ``sosfilt``) connects the filters in series. That multiplies
    their frequency responses, and for non-overlapping bands the product is
    close to zero at every frequency, so the output is nearly silent.

    Parameters
    ----------
    data : array_like
        Signal to filter (filtered along its last axis by ``sosfilt``).
    bands : iterable of (lowcut, highcut)
        Pass-bands in Hz.
    fs : float
        Sample rate in Hz.
    order : int, optional
        Order passed to ``butter_bandpass`` for every band.

    Returns
    -------
    y : numpy.ndarray
        Sum of the per-band filtered signals, same shape as *data*.
    sos_list : list of numpy.ndarray
        The SOS array designed for each band, in the order of *bands*.

    Raises
    ------
    ValueError
        If *bands* is empty.
    """
    sos_list = [butter_bandpass(lowcut, highcut, fs, order=order)
                for lowcut, highcut in bands]
    if not sos_list:
        raise ValueError("bands must contain at least one (lowcut, highcut) pair")

    # Parallel combination: filter with every band separately, then add.
    y = np.sum([sosfilt(sos, data) for sos in sos_list], axis=0)

    return y, sos_list


def get_toy_signal(fs=None, rng=None):
    """Synthesize a noisy, chord-like test melody.

    Every note lasts 0.3 s and is followed by 10 ms of silence. Gaussian noise
    is added, and the result is scaled so that its peak magnitude is below 1.

    Parameters
    ----------
    fs : float, optional
        Sample rate in Hz. When omitted, the module-level ``fs`` is used,
        which is what earlier versions of this function always did.
    rng : numpy.random.Generator, optional
        Source of the additive noise. Defaults to the global ``np.random``
        state; pass a seeded generator for reproducible output.

    Returns
    -------
    t_toy_signal : numpy.ndarray
        Sample times in seconds, starting at 0 with spacing ``1 / fs``.
    toy_signal : numpy.ndarray
        The synthesized signal.

    Raises
    ------
    NameError
        If *fs* is omitted and no module-level ``fs`` exists.
    """
    if fs is None:
        # Backward compatibility: fall back to the global sample rate.
        if "fs" not in globals():
            raise NameError("get_toy_signal() needs a sample rate: pass fs=... "
                            "or define a module-level 'fs'")
        fs = globals()["fs"]

    t = np.arange(0, 0.3, 1 / fs)  # time axis of one 0.3 s note

    # Pitch offsets in octaves relative to 440 Hz (x / 12 = x semitones).
    # Index 0 is -inf, which makes 2 ** ff == 0 and hence sin(0) == 0: a rest.
    fq = [-np.inf] + [x / 12 for x in range(-9, 3, 1)]

    mel = [5, 3, 1, 3, 5, 5, 5, 0, 3, 3, 3, 0, 5, 8, 8, 0, 5, 3, 1, 3, 5, 5, 5, 5, 3, 3, 5, 3, 1]
    acc = [5, 0, 8, 0, 5, 0, 5, 5, 3, 0, 3, 3, 5, 0, 8, 8, 5, 0, 8, 0, 5, 5, 5, 0, 3, 3, 5, 0, 1]

    gap = np.zeros(int(0.01 * fs))
    pieces = []
    for m, a in zip(mel, acc):
        note_signal = np.sum([np.sin(2 * np.pi * 440 * 2 ** ff * t)
                              for ff in [fq[a] - 1, fq[a], fq[m] + 1]], axis=0)
        pieces.extend((note_signal, gap))
    # One concatenate at the end instead of re-copying the growing array per note.
    toy_signal = np.concatenate(pieces)

    if rng is None:
        noise = np.random.normal(0, 1, len(toy_signal))
    else:
        noise = rng.normal(0, 1, len(toy_signal))
    toy_signal = toy_signal + noise

    toy_signal = toy_signal / (np.max(np.abs(toy_signal)) + 0.1)
    t_toy_signal = np.arange(len(toy_signal)) / fs

    return t_toy_signal, toy_signal


if __name__ == "__main__":

    fontsize = 12
    # Sample rate and desired pass-bands (in Hz).
    fs = 3000

    f1, f2 = 100, 200
    f3, f4 = 470, 750
    f5, f6 = 800, 850
    f7, f8 = 1000, 1000.1
    cut_off = [(f1, f2), (f3, f4), (f5, f6), (f7, f8)]
    # Smaller band sets to experiment with:
    #   cut_off = [(f1, f2), (f3, f4)]
    #   cut_off = [(f1, f2)]

    # get_toy_signal() picks up the module-level ``fs`` assigned above.
    t_toy_signal, toy_signal = get_toy_signal()

    fig, ax = plt.subplots(6, 1, figsize=(8, 12))

    # --- Original signal in the time domain --------------------------------
    ax[0].plot(t_toy_signal, toy_signal)
    ax[0].set_title('Original toy_signal', fontsize=fontsize)
    ax[0].set_xlabel('Time (s)', fontsize=fontsize)
    ax[0].set_ylabel('Magnitude', fontsize=fontsize)
    ax[0].set_xlim(left=0, right=max(t_toy_signal))

    # --- One band-pass filter per band, and its frequency response ---------
    sos_list = [butter_bandpass(lowcut, highcut, fs, order=10) for lowcut, highcut in cut_off]

    for band_no, (band, band_sos) in enumerate(zip(cut_off, sos_list), start=1):
        w, h = sosfreqz(band_sos, worN=2000)
        # sosfreqz returns w in rad/sample; convert to Hz.
        ax[1].plot(0.5 * fs * w / np.pi, np.abs(h), label=f'Band {band_no}: {band} Hz')

    ax[1].set_title('Multiband Filter Frequency Response')
    ax[1].set_xlabel('Frequency [Hz]')
    ax[1].set_ylabel('Gain')
    ax[1].legend()

    def plot_spectrogram(axis, signal, title):
        """Draw a log-scaled spectrogram of *signal* (sampled at ``fs``) on *axis*."""
        f, t, Sxx = spectrogram(signal, fs, nperseg=930, noverlap=0)
        # LogNorm needs a strictly positive vmin, so ignore exact zeros when choosing it.
        norm = matplotlib.colors.LogNorm(vmin=Sxx[Sxx > 0].min(), vmax=Sxx.max())
        axis.pcolormesh(t, f, np.abs(Sxx), norm=norm)
        axis.set_title(title, fontsize=fontsize)
        axis.set_xlabel('Time (s)', fontsize=fontsize)
        axis.set_ylabel('Frequency (Hz)', fontsize=fontsize)

    plot_spectrogram(ax[2], toy_signal, 'Spectrogram of original toy_signal')

    # --- Filtering ----------------------------------------------------------
    # Apply the bands in parallel and add the outputs. Cascading them instead
    # (``sosfilt(np.vstack(sos_list), toy_signal)``) multiplies the band
    # responses, which is close to zero everywhere for non-overlapping bands.
    toy_signal_filtered = np.sum([sosfilt(band_sos, toy_signal) for band_sos in sos_list], axis=0)

    plot_spectrogram(ax[3], toy_signal_filtered, 'Spectrogram of filtered toy_signal')

    ax[4].plot(t_toy_signal, toy_signal_filtered)
    ax[4].set_title('Filtered toy_signal', fontsize=fontsize)
    ax[4].set_xlim(left=0, right=max(t_toy_signal))
    ax[4].set_xlabel('Time (s)', fontsize=fontsize)
    ax[4].set_ylabel('Magnitude', fontsize=fontsize)

    # --- Magnitude spectra --------------------------------------------------
    # NOTE(review): fft(..., n=N) truncates inputs longer than N samples, so
    # only the first N / fs seconds enter these spectra. Confirm that is intended.
    N = 1512
    X = fft(toy_signal, n=N)
    Y = fft(toy_signal_filtered, n=N)
    freqs = np.arange(N) / N * fs  # frequency of each FFT bin in Hz

    def plot_spectra(axis):
        """Plot the original and filtered |FFT| in dB on *axis*."""
        axis.plot(freqs, 20 * np.log10(np.abs(X)), 'r-', label='FFT original signal')
        axis.plot(freqs, 20 * np.log10(np.abs(Y)), 'g-', label='FFT filtered signal')
        axis.set_xlim(right=fs / 2)
        axis.set_ylim(bottom=-20)
        axis.set_ylabel(r'Power Spectrum (dB)', fontsize=fontsize)
        axis.set_xlabel("frequency (Hz)", fontsize=fontsize)
        axis.grid()
        axis.legend(loc='upper right')

    plot_spectra(ax[5])
    fig.tight_layout()
    plt.show()

    # The same spectrum comparison again, in a figure of its own.
    fig_spec, ax_spec = plt.subplots()
    plot_spectra(ax_spec)
    fig_spec.tight_layout()
    plt.show()

enter image description here

The following is the result after applying @Warren Weckesser's suggestion, averaging the per-band outputs:

toy_signal_filtered = np.mean([sosfilt(sos, toy_signal) for sos in sos_list], axis=0)

enter image description here

The following is the result after applying @Warren Weckesser's suggestion, summing the per-band outputs:

toy_signal_filtered = np.sum([sosfilt(sos, toy_signal) for sos in sos_list], axis=0)

enter image description here

Here is an example where a narrow band is used:

enter image description here


Solution

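  • The easier and recommended method is what Warren wrote in the comments: just compute the sum of the separately band-pass-filtered signals. (Stacking all band-pass sections into one SOS array connects the filters in series, which multiplies their responses; for non-overlapping bands that product is close to zero everywhere, which is why the filtered signal almost vanished.)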

    That being said, if you want to create and apply a single multi-band filter, you can build one by cascading the filters below. In series their responses multiply, so only frequencies that every stage passes get through:

    • a low-pass (to cut everything above the last pass-band),
    • a high-pass (to cut everything below the first pass-band),
    • N-1 band-stop filters, where N is the number of pass-bands (to cut the gaps between them).

    It may be difficult, though, to keep it stable (be careful with filter orders), and it is harder to make the band edges steep.

    I found this interesting and tried it myself:

    from scipy.signal import butter, lfilter
    import matplotlib.pyplot as plt
    from scipy.fft import fft
    import numpy as np
    
    
    def multi_band_filter(bands, subfilter_order=5, fs=None):
        """Design one IIR filter, in (b, a) form, that passes several bands.

        The filter is a cascade of
          * a high-pass at the lower edge of the first band,
          * a band-stop over each gap between consecutive bands,
          * a low-pass at the upper edge of the last band.
        In series the stage responses multiply, so only frequencies passed by
        every stage survive. Stages are merged by convolving their numerator
        and denominator polynomials.

        Parameters
        ----------
        bands : sequence of (low, high) pairs
            Pass-bands in Hz, expected sorted ascending and non-overlapping.
        subfilter_order : int, optional
            Butterworth order of each stage.
        fs : float, optional
            Sample rate in Hz. Defaults to the module-level ``fs``, which is
            what earlier versions of this function always read.

        Returns
        -------
        b, a : numpy.ndarray
            Numerator and denominator of the combined filter.

        Notes
        -----
        The high-pass is skipped when the first band starts at 0 Hz, and the
        low-pass when the last band reaches Nyquist. ``butter`` rejects those
        critical frequencies, and the stage would pass everything anyway.
        High overall orders in (b, a) form lose numerical precision; the SOS
        variant is more robust.
        """
        if fs is None:
            # Backward compatibility: fall back to the global sample rate.
            if "fs" not in globals():
                raise NameError("multi_band_filter() needs a sample rate: pass fs=... "
                                "or define a module-level 'fs'")
            fs = globals()["fs"]

        nyq = 0.5 * fs
        stages = []  # (b, a) pairs to cascade

        # high-pass: removes everything below the first band
        if bands[0][0] > 0:
            stages.append(butter(subfilter_order, bands[0][0] / nyq, btype='highpass', analog=False))

        # band-stop filters: remove each gap between consecutive bands
        for (_, gap_start), (gap_end, _) in zip(bands[:-1], bands[1:]):
            stages.append(butter(subfilter_order, [gap_start / nyq, gap_end / nyq],
                                 btype='bandstop', analog=False))

        # low-pass: removes everything above the last band
        if bands[-1][1] < nyq:
            stages.append(butter(subfilter_order, bands[-1][1] / nyq, btype='lowpass', analog=False))

        # combine filters: a series connection multiplies the transfer functions,
        # i.e. convolves their coefficient polynomials.
        combined_b = np.array([1.0])
        combined_a = np.array([1.0])
        for b, a in stages:
            combined_b = np.convolve(b, combined_b)
            combined_a = np.convolve(a, combined_a)

        return combined_b, combined_a
    
    
    bands = [[400, 700], [1000, 1500]]

    fs = 8000
    # Stop at 1 - 0.5/fs so exactly fs samples (t = 0 ... 1 - 1/fs) are produced.
    t = np.arange(0, 1 - 0.5/fs, 1/fs)
    # Broadband test signal: one sinusoid per integer frequency 10..3799 Hz, each
    # with a tiny random frequency offset and a random phase in [0, pi).
    signal_to_filter = np.sum([np.sin(2 * np.pi * (freq + 0.01 * np.random.random()) * t + np.pi*np.random.random()) for freq in range(10, 3800)], axis=0)

    b, a = multi_band_filter(bands)
    filtered_signal = lfilter(b, a, signal_to_filter)
    original_spectrum = fft(signal_to_filter)
    filtered_signal_spectrum = fft(filtered_signal)

    # FFT bin k lies at k * fs / N Hz. np.linspace(0, fs, N) would instead put it
    # at k * fs / (N - 1), slightly stretching the axis.
    n_bins = len(original_spectrum)
    freq_axis = np.arange(n_bins) * fs / n_bins

    plt.figure(figsize=(16, 10))
    plt.plot(freq_axis, np.abs(original_spectrum), color='b', label='original')
    plt.plot(freq_axis, np.abs(filtered_signal_spectrum), color='orange', label='filtered')
    plt.xlim([0, fs / 2])
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('|FFT|')
    plt.legend()
    plt.show()
    

    SOS version (second-order sections, which are numerically more robust than one high-order (b, a) polynomial):

    from scipy.signal import butter, sosfilt, freqz
    import matplotlib.pyplot as plt
    from scipy.fft import fft
    import numpy as np
    
    
    def multi_band_filter(bands, subfilter_order=5, fs=None):
        """Design one IIR filter, as second-order sections, that passes several bands.

        The filter is a cascade of
          * a high-pass at the lower edge of the first band,
          * a band-stop over each gap between consecutive bands,
          * a low-pass at the upper edge of the last band.
        In series the stage responses multiply, so only frequencies passed by
        every stage survive.

        Parameters
        ----------
        bands : sequence of (low, high) pairs
            Pass-bands in Hz, expected sorted ascending and non-overlapping.
        subfilter_order : int, optional
            Butterworth order of each stage.
        fs : float, optional
            Sample rate in Hz. Defaults to the module-level ``fs``, which is
            what earlier versions of this function always read.

        Returns
        -------
        numpy.ndarray
            SOS array of shape ``(n_sections, 6)`` for ``sosfilt``. If no stage
            is needed, this is the identity section ``[[1, 0, 0, 1, 0, 0]]``.

        Notes
        -----
        The high-pass is skipped when the first band starts at 0 Hz, and the
        low-pass when the last band reaches Nyquist. ``butter`` rejects those
        critical frequencies, and the stage would pass everything anyway.
        """
        if fs is None:
            # Backward compatibility: fall back to the global sample rate.
            if "fs" not in globals():
                raise NameError("multi_band_filter() needs a sample rate: pass fs=... "
                                "or define a module-level 'fs'")
            fs = globals()["fs"]

        nyq = 0.5 * fs
        all_sos = []

        # high-pass: removes everything below the first band
        if bands[0][0] > 0:
            all_sos.append(butter(subfilter_order, bands[0][0] / nyq, btype='highpass',
                                  analog=False, output='sos'))

        # band-stop filters: remove each gap between consecutive bands
        for (_, gap_start), (gap_end, _) in zip(bands[:-1], bands[1:]):
            all_sos.append(butter(subfilter_order, [gap_start / nyq, gap_end / nyq],
                                  btype='bandstop', analog=False, output='sos'))

        # low-pass: removes everything above the last band
        if bands[-1][1] < nyq:
            all_sos.append(butter(subfilter_order, bands[-1][1] / nyq, btype='lowpass',
                                  analog=False, output='sos'))

        if not all_sos:
            return np.array([[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]])

        # combine filters: stacked sections are applied one after another by sosfilt.
        return np.vstack(all_sos)
    
    
    from scipy.signal import sosfreqz

    bands = [[400, 700], [1000, 1500]]

    fs = 8000
    # Stop at 1 - 0.5/fs so exactly fs samples (t = 0 ... 1 - 1/fs) are produced.
    t = np.arange(0, 1 - 0.5/fs, 1/fs)
    # Broadband test signal: one sinusoid per integer frequency 10..3799 Hz, each
    # with a tiny random frequency offset and a random phase in [0, pi).
    signal_to_filter = np.sum([np.sin(2 * np.pi * (freq + 0.01 * np.random.random()) * t + np.pi*np.random.random()) for freq in range(10, 3800)], axis=0)

    sos = multi_band_filter(bands)
    filtered_signal = sosfilt(sos, signal_to_filter)
    original_spectrum = fft(signal_to_filter)
    filtered_signal_spectrum = fft(filtered_signal)

    # FFT bin k lies at k * fs / N Hz (not k * fs / (N - 1) as np.linspace would give).
    n_bins = len(original_spectrum)
    freq_axis = np.arange(n_bins) * fs / n_bins

    plt.figure(figsize=(16, 10))
    plt.plot(freq_axis, np.abs(original_spectrum), color='b', label='original')
    plt.plot(freq_axis, np.abs(filtered_signal_spectrum), color='orange', label='filtered')
    plt.xlim([0, fs / 2])
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('|FFT|')
    plt.legend()
    plt.show()

    # Frequency response of the combined SOS filter itself; this script has no
    # (b, a) pair, so evaluate the sections directly. With fs given, w is in Hz.
    w, h = sosfreqz(sos, worN=2048, fs=fs)
    plt.figure(figsize=(16, 10))
    # Floor |h| so log10 does not hit 0 where a stage has a zero (e.g. at DC).
    plt.plot(w, 20 * np.log10(np.maximum(np.abs(h), 1e-12)), 'b')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Gain (dB)')
    plt.show()
    

    enter image description here enter image description here

    As you can see, the slope of the filter is not very steep.