python · tensorflow · conv-neural-network

MaxPooling operation on temporal data - select signal with the highest amplitude


Is there a clean way to perform a max-pooling operation on temporal data, so that the signal with the highest amplitude is the output?

For example,

import tensorflow as tf

# sample four sine signals (float endpoints -- tf.linspace and
# tf.math.sin expect floating-point inputs)
t = tf.linspace(0.0, 10.0, 200)
a = 2 * tf.math.sin(t)
b = 0.1 * tf.math.sin(2 * t)
c = 3 * tf.math.sin(0.5 * t)
d = 1 * tf.math.sin(5 * t)
# stack the signals
data = tf.stack([a, b, c, d], -1)
# reshape to an appropriate timeseries of 2D feature maps:
# (batch_size, sequence_length, feature_dim1, feature_dim2, channels)
data = tf.reshape(data, [1, 200, 2, 2, 1])

data will look something like this:

[plot of the four stacked sine signals a, b, c, d]

Now, I want to perform something similar to a MaxPooling2D((2, 2)) operation on data so that only c is kept (as it has the highest amplitude). Clearly, we cannot use the MaxPooling3D and TimeDistributed layers directly, as they would perform pooling at each timestep. I tried alternatives using tf.math.reduce_max() and tf.nn.max_pool_with_argmax, but they were not straightforward.
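For this particular 2×2 example, the desired selection can be sketched directly with tf.math.reduce_max and tf.gather. This is a minimal sketch, not the answer's code: the variable names are mine, and the peak of |signal| over time is used as the amplitude.

```python
import tensorflow as tf

# Recreate the sample signals from the question.
t = tf.linspace(0.0, 10.0, 200)
a = 2 * tf.math.sin(t)
b = 0.1 * tf.math.sin(2 * t)
c = 3 * tf.math.sin(0.5 * t)
d = 1 * tf.math.sin(5 * t)
data = tf.reshape(tf.stack([a, b, c, d], -1), [1, 200, 2, 2, 1])

# Flatten the single 2x2 patch into 4 candidate signals: (batch, time, 4).
signals = tf.reshape(data, [1, 200, 4])

# Amplitude of each candidate = peak of |signal| over the whole time axis.
peaks = tf.math.reduce_max(tf.math.abs(signals), axis=1)    # (batch, 4)

# Index of the loudest candidate, then gather its full time series.
winner = tf.math.argmax(peaks, axis=-1)                     # (batch,)
pooled = tf.gather(signals, winner, axis=-1, batch_dims=1)  # (batch, time)

# pooled[0] is c, the signal with amplitude 3.
```

This collapses one whole patch per batch item; the answer below generalizes the idea to several patches and channels.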

Any suggestions or comments are appreciated. Thanks in advance :)


Solution

  • Here is my implementation for the above question:

    def temporal_max_pooler(signal_stack):
      # signal_stack: (batch_size, T, f, f, depth)
      overall_stack = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
      for ch in range(signal_stack.shape[-1]):
        # take one channel at a time, keeping the channel axis
        ch_signals = signal_stack[:, :, :, :, ch:ch + 1]
        # cut each 2D feature map into non-overlapping 2x2 patches
        patches = tf.extract_volume_patches(ch_signals, (1, 1, 2, 2, 1),
                                            (1, 1, 2, 2, 1), 'VALID')
        patches = tf.transpose(patches, [0, 1, 4, 2, 3])
        (s0, s1, s2, s3, s4) = patches.shape
        patches = tf.reshape(patches, [s0, s1, s2 // 2, s2 // 2, s4 * s3])
        ch_stack = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
        for p in range(patches.shape[-1]):
          # the four signals of patch p, flattened: (batch_size, T, 4)
          p_signals = tf.reshape(patches[:, :, :, :, p],
                                 (patches.shape[0], patches.shape[1], -1))
          # peak of every signal over time, then the index of the loudest one
          max_amps = tf.math.reduce_max(p_signals, 1)
          where_is_max = tf.math.argmax(max_amps, -1)
          # keep the full time series of the winning signal
          winners = tf.gather(p_signals, where_is_max, axis=-1, batch_dims=1)
          # TensorArray.write returns the updated array -- it must be
          # reassigned, otherwise the write is silently lost
          ch_stack = ch_stack.write(ch_stack.size(), winners)
        ch_stack = tf.transpose(ch_stack.stack(), [1, 2, 0])
        # restore the spatial layout: an n_d x n_d pooled feature map
        n_d = tf.math.sqrt(tf.cast(ch_stack.shape[-1], 'float32'))
        n_d = tf.cast(n_d, 'int32')
        ch_stack = tf.reshape(ch_stack, [ch_stack.shape[0], ch_stack.shape[1],
                                         n_d, n_d])
        overall_stack = overall_stack.write(overall_stack.size(), ch_stack)
      overall_stack = overall_stack.stack()
      # (depth, bs, T, n_d, n_d) -> (bs, T, n_d, n_d, depth)
      return tf.transpose(overall_stack, [1, 2, 3, 4, 0])
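For the fixed 2×2 window, the same selection can also be written without Python loops or TensorArrays by splitting the patch positions into axes of their own. This is a sketch, not the answer's code: the function name is mine, it assumes even spatial dimensions with static shapes, and it picks winners by peak value over time, just as the loop version above does (apply tf.abs before the reduce_max if you want peak amplitude instead).

```python
import tensorflow as tf

def temporal_max_pool_2x2(x):
    # x: (batch, time, H, W, channels) with H and W even.
    bs, T, H, W, C = x.shape
    # Split H and W so each 2x2 patch gets its own pair of axes:
    # (batch, time, H/2, 2, W/2, 2, C)
    x = tf.reshape(x, [bs, T, H // 2, 2, W // 2, 2, C])
    # Move the two patch axes to the end and flatten them to 4 candidates:
    # (batch, time, H/2, W/2, C, 4)
    x = tf.transpose(x, [0, 1, 2, 4, 6, 3, 5])
    x = tf.reshape(x, [bs, T, H // 2, W // 2, C, 4])
    # Peak of every candidate over the whole time axis, then the winner.
    peaks = tf.math.reduce_max(x, axis=1)        # (batch, H/2, W/2, C, 4)
    winner = tf.math.argmax(peaks, axis=-1)      # (batch, H/2, W/2, C)
    # Select each winning series with a one-hot mask instead of gather.
    mask = tf.one_hot(winner, depth=4, dtype=x.dtype)
    return tf.einsum('bthwcp,bhwcp->bthwc', x, mask)
```

On the sample data from the question this returns a (1, 200, 1, 1, 1) tensor whose single series is c.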