Search code examples
Tags: python, numpy, matplotlib, scipy, eeglab

Difficulty plotting spectrogram for EEG data in Python


I'm working on a task that reproduces some plots from the paper "Regulation of brain cognitive states through auditory, gustatory, and olfactory stimulation with wearable monitoring", but I have a problem when trying to plot a spectrogram of EEG data in Python with the SciPy library.

I'm trying to plot a spectrogram for EEG data recorded from a Muse headband dataset. Here are the steps I've followed based on the instructions provided:

  1. Preprocessing: I applied a band-pass filter (high-pass at 1 Hz, low-pass at 50 Hz) to the raw EEG signals using signal.butter. Then I tried to downsample the signals from 256 Hz to 128 Hz using signal.resample and signal.decimate, but I get a ValueError: "Length of values does not match length of index."

  2. Plotting Spectrogram: I used signal.spectrogram to compute the spectrogram for channels TP9 and TP10. However, when plotting, only the axes are visible, and there's no EEG signal.

# Extract the desired channels (TP9 and TP10)
# NOTE(review): this is the snippet that produces the ValueError described in
# the question; it is left as posted so it matches that report. The review
# notes below point at the lines responsible.

channels = ['TP9', 'TP10']  
# NOTE(review): eeg_df[channels] is a column selection without an explicit
# .copy(), so the column assignments in the loop below may make pandas emit a
# SettingWithCopyWarning.
eeg_data = eeg_df[channels]

# Convert the timestamps to time in seconds
# NOTE(review): time_seconds is not used anywhere below. Dividing by the
# sampling rate only yields seconds if 'timestamps' holds sample counts; if it
# already holds seconds (TODO confirm against the dataset), subtract the first
# timestamp instead.

sampling_rate = 256  # Sampling rate of the EEG data in Hz
time_seconds = eeg_df['timestamps'] / sampling_rate

# Preprocess the EEG signals

for channel in channels:
    # Apply high-pass filter above 1 Hz and low-pass filter below 50 Hz
    # (Butterworth band-pass designed with N=4, applied forwards and backwards
    # by filtfilt). The design does not depend on `channel`, so it could be
    # computed once before the loop.
    b, a = signal.butter(4, [1, 50], btype='bandpass', fs=sampling_rate)
    # NOTE(review): if these columns contain NaNs (TODO check the raw data),
    # filtfilt spreads them through the whole signal, and an all-NaN
    # spectrogram would render as empty axes.
    eeg_data[channel] = signal.filtfilt(b, a, eeg_data[channel])
    
    # Downsample the signal from 256 to 128 Hz
    # NOTE(review): this is the failing line. new_length is half the number of
    # rows, but a column assigned into eeg_data must match its full-length
    # index, so the assignment raises "ValueError: Length of values does not
    # match length of index". Store the downsampled signals in a separate
    # DataFrame of length new_length instead.

    new_length = len(eeg_data[channel]) // 2
    eeg_data[channel] = signal.resample(eeg_data[channel], new_length)

# Define the parameters for the spectrogram

window = 'hann'  # Windowing function

nperseg = 128    # Length of each segment

noverlap = 64    # Overlap between segments

nfft = 128       # Number of points for the FFT

# Compute the spectrogram for each channel
# fs=sampling_rate / 2 is the intended 128 Hz rate after downsampling, so
# nperseg=128 gives 1 s segments and frequency bins fs / nfft = 1 Hz apart.

frequencies, times, spectrogram_TP9 = signal.spectrogram(eeg_data['TP9'], fs=sampling_rate / 2, window=window,nperseg=nperseg, noverlap=noverlap, nfft=nfft)
frequencies, times, spectrogram_TP10 = signal.spectrogram(eeg_data['TP10'], fs=sampling_rate / 2, window=window,nperseg=nperseg, noverlap=noverlap, nfft=nfft)

# Plot the spectrogram for TP9
# NOTE(review): 10 * np.log10(...) yields -inf for any bin with zero power.

plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.pcolormesh(times, frequencies, 10 * np.log10(spectrogram_TP9), shading='gouraud', cmap='viridis')
plt.title('Spectrogram for TP9')
plt.xlabel('Time (seconds)')
plt.ylabel('Frequency (Hz)')
plt.colorbar(label='Power/Frequency (dB/Hz)')
plt.ylim(0, 30)  # Limit the y-axis to frequencies up to 30 Hz

# Plot the spectrogram for TP10

plt.subplot(2, 1, 2)
plt.pcolormesh(times, frequencies, 10 * np.log10(spectrogram_TP10), shading='gouraud', cmap='viridis')
plt.title('Spectrogram for TP10')
plt.xlabel('Time (seconds)')
plt.ylabel('Frequency (Hz)')
plt.colorbar(label='Power/Frequency (dB/Hz)')
plt.ylim(0, 30)  # Limit the y-axis to frequencies up to 30 Hz

plt.tight_layout()
plt.show()

I suspect the issue might be related to mismatched lengths after downsampling or incorrect usage of the signal processing functions. Could someone please help me troubleshoot this issue and explain how to correctly plot the spectrogram for the EEG data? I have attached images of the error, as well as information about my dataset and its values.

Any assistance or suggestions would be greatly appreciated! Thank you!

EEG figure desired:

EEG figure desired


Solution

  • I've sought to create the desired plot using similar data. I started by visualising the original and filtered signals, before creating the spectrograms.

    I think the error you were seeing was because you were assigning a downsampled signal to the original dataframe, which won't work as they are different lengths. To get round this, I assigned the downsampled data to a new dataframe eeg_downsampled.

    Original and processed channels: enter image description here

    Spectrograms: enter image description here


    import pandas as pd
    import numpy as np
    from matplotlib import pyplot as plt
    
    from scipy import signal
    
    #
    # Load some test data and modify it to look more like the OP's data
    #
    eeg_df = pd.read_csv('../musedata_meditation.csv')
    # Rename the raw Muse columns so the rest of the script can use 'TP9'/'TP10'
    eeg_df = eeg_df.rename(columns={'RAW_TP9': 'TP9', 'RAW_TP10': 'TP10'})
    eeg_df['sample_number'] = range(len(eeg_df))
    
    # Sampling rate of the EEG data in Hz
    sampling_rate = 256
    
    # Convert the sample index to time in seconds
    # This means you need to divide the sample *number* by the sampling rate
    eeg_df['time_seconds'] = eeg_df['sample_number'] / sampling_rate
    
    # Extract the desired channels (TP9 and TP10).
    # .copy() makes eeg_data an independent frame, so adding the filtered
    # columns below neither writes into nor warns about eeg_df.
    keep_channels = ['TP9', 'TP10']
    eeg_data = eeg_df[keep_channels + ['time_seconds']].copy()
    
    #
    # Preprocess the EEG signals
    #
    
    # Define the pre-processing parameters
    bp_order = 4
    bp_passband = [1, 50]
    downsample_factor = 2
    downsampled_rate = sampling_rate / downsample_factor  # 128 Hz
    
    # The band-pass filter is identical for every channel, so design it once:
    # high-pass above 1 Hz and low-pass below 50 Hz
    b, a = signal.butter(bp_order, bp_passband, btype='bandpass', fs=sampling_rate)
    
    # The downsampled signals are shorter than eeg_data, so they cannot be
    # stored back into it (that is what raised "Length of values does not
    # match length of index"); collect them in their own dataframe instead.
    eeg_downsampled = pd.DataFrame()
    downsampled_length = len(eeg_data) // downsample_factor
    
    for channel in keep_channels:
        filtered_name = channel + '_bp'
        # Zero-phase band-pass filtering (filtfilt runs the filter forwards
        # and backwards)
        eeg_data[filtered_name] = signal.filtfilt(b, a, eeg_data[channel])
        
        # Downsample the filtered signal from 256 to 128 Hz
        eeg_downsampled[channel] = signal.resample(eeg_data[filtered_name], downsampled_length)
    # Time axis for the downsampled signals, sampled at downsampled_rate
    eeg_downsampled['time_seconds'] = np.arange(downsampled_length) / downsampled_rate
    
    #
    # View the original and processed data
    #
    plot_height = 2
    n_plots = 3
    f, axs = plt.subplots(nrows=n_plots, ncols=1,
                          figsize=(11, plot_height * n_plots),
                          layout='constrained', sharex=True)
    
    plot_slice = slice(0, 2500)  # window of data to plot, in original samples
    
    # Original data
    eeg_data[plot_slice].plot(
        y=keep_channels, x='time_seconds',
        ylabel='signal', xlabel='time (s)',
        linewidth=1, ax=axs[0],
        title=f'original data (samples {plot_slice.start} to {plot_slice.stop})'
    )
    
    # Filtered
    eeg_data[plot_slice].plot(
        y=[chan + '_bp' for chan in keep_channels], x='time_seconds',
        ylabel='signal', xlabel='time (s)', legend=False,
        linewidth=1, ax=axs[1], title=f'bandpass {bp_passband}Hz'
    )
    
    # Downsampled: the same time window spans downsample_factor times fewer
    # samples, so both ends of the slice are scaled
    downsampled_slice = slice(plot_slice.start // downsample_factor,
                              plot_slice.stop // downsample_factor)
    eeg_downsampled[downsampled_slice].plot(
        y=keep_channels, x='time_seconds',
        ylabel='signal', xlabel='time (s)', legend=False,
        linewidth=1, ax=axs[2], title=f'downsampled {downsample_factor}x'
    )
    
    
    # Define the parameters for the spectrogram
    window = 'hann'  # Windowing function
    nperseg = 128    # Length of each segment (1 s at 128 Hz)
    noverlap = 64    # Overlap between segments
    nfft = 128       # Number of points for the FFT (bins fs / nfft = 1 Hz apart)
    
    # Compute the spectrogram for each downsampled channel. All channels have
    # the same length, so `frequencies` and `times` are shared by all of them.
    spectrograms = {}
    for channel in keep_channels:
        frequencies, times, spectrogram = signal.spectrogram(
            eeg_downsampled[channel],
            fs=downsampled_rate,
            window=window, nperseg=nperseg, noverlap=noverlap, nfft=nfft
        )
        spectrograms[channel] = spectrogram
    
    # Plot the spectrograms
    f, axs = plt.subplots(nrows=2, ncols=1, figsize=(11, 5),
                          layout='constrained', sharex=True)
    f.suptitle('Spectrograms')
    vmin, vmax = -10, 30
    
    for channel, ax in zip(keep_channels, axs):
        im = ax.pcolormesh(
            times, frequencies, 10 * np.log10(spectrograms[channel]),
            vmin=vmin, vmax=vmax,  # comment out for auto-range
            shading='nearest', cmap='viridis'
        )
        
        ax.set_ylim(0, 30)  # Limit the y-axis to frequencies up to 30 Hz
        ax.set_ylabel(('Left' if channel=='TP9' else 'Right') + f' ({channel})\nfrequency (Hz)')
        
        # Attach each colorbar to its own axes explicitly; without ax=...,
        # some Matplotlib versions take the space from the current axes.
        f.colorbar(im, ax=ax, aspect=8, pad=0.01, label='power/frequency (dB/Hz)')
        
    ax.set_xlabel('time (s)')
    plt.show()