I am trying to implement local-p attention based on this paper: https://arxiv.org/pdf/1508.04025.pdf. Specifically, equation (9) derives an alignment position by passing the current hidden state through a small feed-forward network, taking the sigmoid of the result, and multiplying it by the number of timesteps. Since the sigmoid returns values between 0 and 1, this multiplication yields a valid (fractional) position between 0 and the number of timesteps.

I can soft-round this to infer the predicted position; however, I couldn't find a way to convert it to an integer for use in slicing/indexing operations, since tf.cast() is not differentiable. Another problem is that the derived positions have shape (B, 1), i.e. one aligned position for each example in the batch. See below to understand these operations.
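For reference, my reading of equation (9) of the paper, where S is the number of source timesteps and h_t is the current target hidden state, is:

p_t = S \cdot \mathrm{sigmoid}\left(v_p^{\top} \tanh(W_p h_t)\right)

This is what the snippet below computes, one p_t per example in the batch.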
"""B = batch size, S = sequence length (num. timesteps), V = vocabulary size, H = number of hidden dimensions"""
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Activation, Lambda


class LocalAttention(Layer):
    def __init__(self, size, window_width=None, **kwargs):
        super(LocalAttention, self).__init__(**kwargs)
        self.size = size
        self.window_width = window_width  # 2*D

    def build(self, input_shape):
        self.W_p = Dense(units=input_shape[2], use_bias=False)
        self.W_p.build(input_shape=(None, None, input_shape[2]))  # (B, 1, H)
        self._trainable_weights += self.W_p.trainable_weights

        self.v_p = Dense(units=1, use_bias=False)
        self.v_p.build(input_shape=(None, None, input_shape[2]))  # (B, 1, H)
        self._trainable_weights += self.v_p.trainable_weights

        super(LocalAttention, self).build(input_shape)

    def call(self, inputs):
        sequence_length = inputs.shape[1]
        ## Get h_t, the current (target) hidden state ##
        target_hidden_state = Lambda(function=lambda x: x[:, -1, :])(inputs)  # (B, H)
        ## Get p_t, the aligned position, from h_t ##
        aligned_position = self.W_p(target_hidden_state)             # (B, H)
        aligned_position = Activation('tanh')(aligned_position)      # (B, H)
        aligned_position = self.v_p(aligned_position)                # (B, 1)
        aligned_position = Activation('sigmoid')(aligned_position)   # (B, 1)
        aligned_position = aligned_position * sequence_length        # (B, 1)
Let's say the aligned_position tensor has elements [24.2, 15.1, 12.3] for a batch size B = 3, for simplicity. Then the source hidden states are derived from the input hidden states (B=3, S, H) such that, for the first example, we take timesteps starting from 24, i.e. something along the lines of first_batch_states = Lambda(function=lambda x: x[:, 24:, :])(inputs), and so on. Note that the implementation of local-p attention is more complicated than this; I have simplified it here. Hence, the main challenge is converting 24.2 to 24 without losing differentiability, or using some sort of mask operation to get the indices through a dot product. The mask operation is preferred, since we would have to do this for every example in the batch, and having a loop inside a custom Keras layer is not neat. Do you have any ideas on how to accomplish this task? I will appreciate any answers and comments!
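To illustrate the differentiability issue, here is a minimal check (my own snippet, assuming TF 2.x eager execution), showing that an integer cast kills the gradient:

import tensorflow as tf

x = tf.Variable(24.2)
with tf.GradientTape() as tape:
    y = tf.cast(tf.cast(x, tf.int32), tf.float32)  # hard rounding: 24.2 -> 24.0
print(tape.gradient(y, x))  # None, because the integer cast has no gradient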
There are two ways I found to go about solving this problem. The first one keeps the single aligned position p_t per example and, instead of hard slicing, rescales the attention weights with a Gaussian window centered at p_t, as the paper itself suggests for local-p attention.
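As I read equation (10) of the paper, the local-p weighting is

a_t(s) = \mathrm{align}(h_t, h_s) \exp\left(-\frac{(s - p_t)^2}{2\sigma^2}\right), \quad \sigma = \frac{D}{2}

(the snippet below plugs in sigma = window_width / 2, i.e. sigma = D given window_width = 2D, which gives a slightly wider window than the paper's D/2):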
# Inside call(), continuing from the question's snippet (Concatenate comes from
# tensorflow.keras.layers); attention_weights is assumed to hold the alignment
# scores between h_t and the source hidden states.
gaussian_estimation = lambda s: tf.exp(-tf.square(s - aligned_position) /
                                       (2 * tf.square(self.window_width / 2)))
gaussian_factor = gaussian_estimation(0)
for i in range(1, sequence_length):
    gaussian_factor = Concatenate()([gaussian_factor, gaussian_estimation(i)])  # (B, S)
# Adjust weights via gaussian_factor: (B, S*) to allow differentiability
attention_weights = attention_weights * gaussian_factor  # (B, S*)
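For what it's worth, here is a standalone toy check of this weighting (my own sketch, not part of the layer, using the example positions [24.2, 15.1, 12.3] from the question and dummy attention weights); it also shows that the Python loop can be replaced by a single broadcast over tf.range:

import tensorflow as tf

B, S, D = 3, 30, 5                                          # batch, timesteps, half-window
aligned_position = tf.constant([[24.2], [15.1], [12.3]])    # (B, 1)
attention_weights = tf.ones((B, S))                         # dummy alignment scores
positions = tf.cast(tf.range(S), tf.float32)                # (S,)
gaussian_factor = tf.exp(-tf.square(positions - aligned_position) /
                         (2 * tf.square(float(D))))         # (B, S), sigma = D
attention_weights = attention_weights * gaussian_factor     # now peaks around 24, 15 and 12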
It should be noted that no hard slicing operation is involved here; the attention weights are simply rescaled according to each timestep's distance from the aligned position, which keeps everything differentiable. The second approach instead applies the dense layers to all source hidden states and keeps only the top scores:
# Inside call(), as before; Multiply comes from tensorflow.keras.layers.
aligned_position = self.W_p(inputs)                           # (B, S, H)
aligned_position = Activation('tanh')(aligned_position)       # (B, S, H)
aligned_position = self.v_p(aligned_position)                 # (B, S, 1)
aligned_position = Activation('sigmoid')(aligned_position)    # (B, S, 1)
## Only keep the top k = window_width values of the sigmoid activation, and zero out the rest ##
aligned_position = tf.squeeze(aligned_position, axis=-1)      # (B, S)
top_probabilities = tf.nn.top_k(input=aligned_position,
                                k=self.window_width,
                                sorted=False)                  # (values: (B, k), indices: (B, k))
onehot_vector = tf.one_hot(indices=top_probabilities.indices,
                           depth=sequence_length)              # (B, k, S)
onehot_vector = tf.reduce_sum(onehot_vector, axis=1)           # (B, S)
aligned_position = Multiply()([aligned_position, onehot_vector])   # (B, S)
aligned_position = tf.expand_dims(aligned_position, axis=-1)       # (B, S, 1)
source_hidden_states = Multiply()([inputs, aligned_position])      # (B, S*=S(D), H)
## Scale back to approximately the original hidden state values ##
aligned_position += 1                                              # (B, S, 1)
source_hidden_states /= aligned_position                           # (B, S*=S(D), H)
It should be noted that here we instead apply the dense layers to all source hidden states, so that aligned_position has shape (B, S, 1) rather than (B, 1). I believe this is as close as we can get to what the paper suggests.
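To make the top-k masking trick concrete, here is a small standalone check (my own sketch with made-up scores) showing how tf.one_hot over the top-k indices produces a (B, S) binary mask:

import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7],
                      [0.6, 0.1, 0.5, 0.2, 0.4, 0.3]])          # (B=2, S=6) toy scores
top = tf.nn.top_k(scores, k=2, sorted=False)                    # values: (2, 2), indices: (2, 2)
mask = tf.reduce_sum(tf.one_hot(top.indices, depth=6), axis=1)  # (2, 6) binary mask
print(mask.numpy())
# [[0. 1. 0. 1. 0. 0.]
#  [1. 0. 1. 0. 0. 0.]]
masked_scores = scores * mask                                   # zero out all but the top k scores

Gradients flow into scores through the final multiplication (the indices themselves are treated as constants), which is what makes this a usable substitute for hard slicing.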
Anybody who is trying to implement attention mechanisms can check out my repo: https://github.com/uzaymacar/attention-mechanisms. The layers there are designed for many-to-one sequence tasks, but can be adapted to other settings with minor tweaks.