Tags: tensorflow, machine-learning, deep-learning, lstm, signal-processing

Using an LSTM model with 2 simultaneous signals to denoise one of the signals


In the medical imaging field, it is common to have two signals where artifact from one signal bleeds into the other. I'm exploring whether a deep learning model can be used with these signals to regress out the artifact (this is for a student project and for my own curiosity).

Here is a grossly oversimplified example: We are given two signals, signal_X and signal_Y. Signal_X is a flat line with an occasional square wave pulse. Signal_Y is a pure sine wave. When signal_X pulses, it leaks into signal_Y such that signal_Y's amplitude increases. We consider this to be artifact.

The goal is to use an LSTM model to learn the relationship between signal_X and signal_Y, so that we can regress out the "noise" injected into signal_Y when signal_X pulses.

Below I am training the model on (x, y) pairs from signal_X and signal_Y. Then I apply that model to signal_X itself to get predicted values. I then define the corrected signal_Y as

corrected_signal_Y = Y_train - pred_Y

The result isn't great, but I feel that I may be way off base here. Am I actually just training my model to predict signal_Y (artifact included)? If so, then my calculation of "corrected_Y" must be conceptually wrong.

Minimal working example

You can use the code below, or just run it in this Colab notebook

YOU CAN SKIP THIS: This first code block is just to generate the signals

import numpy as np
import matplotlib.pyplot as plt

def simulate_signals(N, sine_amplitude, sine_frequency, square_amplitude, square_duration, pulse_prob):
  """simulate a square wave (signal X) and a sine wave (signal Y)


  Parameters
  ----------
  N : int
    The length of the signals, in samples.
  sine_amplitude : int | float
    Amplitude of the sine wave
  sine_frequency : float
    Frequency of the sine wave, in cycles per sample
  square_amplitude : int | float
    Amplitude of the artifact square wave added to signal Y
  square_duration : int
    Duration of each square wave pulse in signal X, in samples
  pulse_prob : float
    Probability of a square wave pulse starting in signal X at each time step

  Returns
  -------
  signal_X, signal_Y : np.ndarray
  """
  # Generate time axis
  t = np.arange(N)

  # Simulate signal X with occasional square wave pulses
  signal_X = np.zeros(N)
  for i in range(N):
      if np.random.rand() < pulse_prob:
          signal_X[i:i+square_duration] = 1

  # Simulate signal Y as a pure sine wave with added artifact square wave component
  signal_Y_sine = sine_amplitude * np.sin(2 * np.pi * sine_frequency * t)
  signal_Y_square = square_amplitude * signal_X
  signal_Y = signal_Y_sine + signal_Y_square

  return signal_X, signal_Y

# Parameters for the simulation
N = 1000
sine_amplitude = 1.0
sine_frequency = 0.05
square_amplitude = 0.5
square_duration = 20
pulse_prob = 0.02

# Simulate signals X and Y
signal_X, signal_Y = simulate_signals(N, sine_amplitude, sine_frequency,
                                      square_amplitude, square_duration,
                                      pulse_prob)

# Plot the signals
plt.figure(figsize=(12, 6))
plt.subplot(2, 1, 1)
plt.plot(signal_X, label='Signal X')
plt.xlabel('Time')
plt.ylabel('Amplitude')
plt.legend()
plt.xlim(0, 1000)


plt.subplot(2, 1, 2)
plt.plot(signal_Y, label='Signal Y')
plt.xlabel('Time')
plt.ylabel('Amplitude')
plt.legend()
plt.xlim(0, 1000)

plt.tight_layout()
plt.show()

Image of signal_X and signal_Y

THIS IS THE CODE BLOCK WITH THE MODEL

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dropout, LSTM

from sklearn.preprocessing import StandardScaler


scaler_X = StandardScaler()
scaler_Y = StandardScaler()
# standardize each signal and reshape to (n_samples, 1)
X = scaler_X.fit_transform(signal_X.reshape(-1, 1))
Y = scaler_Y.fit_transform(signal_Y.reshape(-1, 1))
print(f"X, Y shape: {X.shape}, {Y.shape}")

nb_time = 100
n_blocks = 50
model = Sequential()


# Cut the signals into non-overlapping windows of nb_time samples
n = X.shape[0] // nb_time  # i.e. 1000 // 100 = 10 windows
X_train = X[:nb_time * n, :].reshape((-1, nb_time, X.shape[-1]), order='C')
Y_train = Y[:nb_time * n, :].reshape((-1, nb_time, Y.shape[-1]), order='C')


# An LSTM layer expects 3D input with shape (n_samples, n_timesteps, n_features)
model.add(LSTM(n_blocks, input_shape=(nb_time, X_train.shape[-1]),
               return_sequences=True))
model.add(Dropout(0.5))
model.add(LSTM(Y.shape[1],
               return_sequences=True))
model.add(Dropout(0.5))


# try this optimizer instead of 'adam'
adagrad = tf.keras.optimizers.Adagrad(learning_rate=1)
model.compile(loss='mean_squared_error', optimizer=adagrad)
model.summary()


model.fit(X_train, Y_train, epochs=50,
          validation_split=0.2, batch_size=1, verbose=2)


# Visualize loss
fig, ax = plt.subplots()
ax.plot(model.history.history["loss"], label="Training Loss") # blue
ax.plot(model.history.history["val_loss"], label="Validation Loss") # orange
plt.legend()

Visualization of the Loss

# Make predictions
predicted_Y = model.predict(X_train)

# Denormalize predictions
predicted_Y_denormalized = scaler_Y.inverse_transform(predicted_Y.reshape(-1, 1))

# Correct the signal: I'm unsure whether this is conceptually correct
corrected_signal_Y = Y_train.reshape((-1,1)) - predicted_Y_denormalized

# Plot original signal_Y and predicted signal_Y
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(Y_train.reshape((-1,1)), label='Original Signal Y')
ax.plot(predicted_Y_denormalized, label='Predicted Signal Y')
ax.plot(corrected_signal_Y, label="Corrected Signal_Y")
ax.set_xlabel('Time')
ax.set_ylabel('Amplitude')
ax.legend()
plt.show()

Plot of signal_Y, pred_Y, and corrected_Y


Solution

  • I think you need to supply the square wave signal_X and the noisy signal signal_Y as input features (or just one of those, if you prefer), and use the clean signal signal_Y_sine as the target. The model should then learn to reproduce the clean signal from the input feature(s); a minimal sketch of this setup follows the plotting code below.

    Plot of signal_X, signal_Y, and the clean target signal_Y_sine

    # assumes simulate_signals was modified to also return the clean sine as a third output
    signal_X, signal_Y, signal_Y_target = simulate_signals(...)
    
    # Plot the signals
    f, ax = plt.subplots(figsize=(10, 2.5))
    ax.plot(signal_X, 'r-', lw=1.3, label='signal_X')
    ax.plot(signal_Y, 'tab:green', lw=1.3, label='signal_Y')
    ax.plot(signal_Y_target, lw=2, color='tab:brown', label='Target: signal_Y_sine')
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('n')
    f.legend(loc='upper center', ncol=3)
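
    For concreteness, below is a minimal sketch of that setup (my own illustration, not code from the original post). It assumes a variant of simulate_signals, here called simulate_signals_with_target, that also returns the clean sine; the names scaler_F, scaler_T, F_train, T_train and denoised_Y are likewise illustrative. The two signals are stacked as input features and the LSTM is trained to output the clean target directly, so its prediction is already the denoised signal and no subtraction is needed.

    import numpy as np
    from sklearn.preprocessing import StandardScaler
    from tensorflow.keras.layers import Dense, Input, LSTM
    from tensorflow.keras.models import Sequential

    def simulate_signals_with_target(N, sine_amplitude, sine_frequency,
                                     square_amplitude, square_duration, pulse_prob):
        # Same generator as in the question, but the clean sine is returned as well
        t = np.arange(N)
        signal_X = np.zeros(N)
        for i in range(N):
            if np.random.rand() < pulse_prob:
                signal_X[i:i + square_duration] = 1
        signal_Y_sine = sine_amplitude * np.sin(2 * np.pi * sine_frequency * t)
        signal_Y = signal_Y_sine + square_amplitude * signal_X
        return signal_X, signal_Y, signal_Y_sine

    signal_X, signal_Y, signal_Y_target = simulate_signals_with_target(
        N=1000, sine_amplitude=1.0, sine_frequency=0.05,
        square_amplitude=0.5, square_duration=20, pulse_prob=0.02)

    # Two input features per time step: the artifact channel and the noisy channel
    features = np.stack([signal_X, signal_Y], axis=-1)   # shape (N, 2)
    target = signal_Y_target.reshape(-1, 1)              # shape (N, 1)

    # Standardize features and target with separate scalers
    scaler_F, scaler_T = StandardScaler(), StandardScaler()
    features = scaler_F.fit_transform(features)
    target = scaler_T.fit_transform(target)

    # Cut into non-overlapping windows: (n_windows, nb_time, n_features)
    nb_time = 100
    n = features.shape[0] // nb_time
    F_train = features[:n * nb_time].reshape(n, nb_time, 2)
    T_train = target[:n * nb_time].reshape(n, nb_time, 1)

    # Sequence-to-sequence model: one denoised value per time step
    model = Sequential([
        Input(shape=(nb_time, 2)),
        LSTM(50, return_sequences=True),
        Dense(1),
    ])
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(F_train, T_train, epochs=50, batch_size=4, verbose=2)

    # The model output is the denoised estimate of signal_Y; no subtraction needed
    pred = model.predict(F_train)
    denoised_Y = scaler_T.inverse_transform(pred.reshape(-1, 1))

    In practice you would hold some windows out for validation and, ideally, use overlapping (sliding) windows so the LSTM sees more training examples.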