Tags: tensorflow, deep-learning, speech-recognition, tf.keras, transformer-model

What is the attention penalty in the Speech-Transformer paper? (updated)


GitHub: https://github.com/sephiroce/tfsr/tree/exprimental

I'm trying to reproduce the recognition accuracies reported in the Speech-Transformer paper [1]. One technique I could not fully understand is the attention penalty. The paper describes it as follows:

"In addition, we encouraged the model attending to closer positions by adding bigger penalty on the attention weights of more distant position-pairs."

My understanding is that it means adding increasingly negative values to the scaled attention logits (before masking) the farther a position pair lies from the diagonal, in every multi-head attention except the first one in the decoder.
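Written out, that reading looks roughly like this. Here $q_i$ and $k_j$ are the query and key vectors, $d_k$ is the key dimension, $M_{ij}$ is the usual mask term ($-10^9$ for masked pairs, $0$ otherwise), and $\lambda$ (a penalty scale) and $f$ (a non-decreasing function of the distance with $f(0) = 0$) are placeholders, since the paper does not specify them:

```latex
$$
\tilde{e}_{ij} = \frac{q_i k_j^{\top}}{\sqrt{d_k}} - \lambda\, f\bigl(|i-j|\bigr) + M_{ij},
\qquad
\alpha_{ij} = \frac{\exp(\tilde{e}_{ij})}{\sum_{j'} \exp(\tilde{e}_{ij'})}
$$
```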

Here is the code snippet I use to compute the attention weights:

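import tensorflow as tf

# Wrapper name and signature are assumed for illustration; the body is the original snippet.
def compute_attention_weights(query, key, mask=None,
                              attention_penalty=None, att_penalty_scale=0.0):
  # Q * trans(K): (..., seq_len_q, seq_len_k)
  matmul_qk = tf.matmul(query, key, transpose_b=True)

  # scaled matmul_qk: ( Q * trans(K) ) / sqrt(d_k)
  dimension_of_key = tf.cast(tf.shape(key)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dimension_of_key)

  # add the mask to the scaled tensor
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

  # Add the penalty to the attention weights and linearly re-normalize them.
  if attention_penalty is not None and att_penalty_scale > 0:
    attention_weights += (attention_penalty * att_penalty_scale)
    # Shift by the global minimum (taken over the whole batch) so that all
    # weights are non-negative; this can also give masked positions non-zero weight.
    attention_weights += tf.math.abs(tf.math.reduce_min(attention_weights))
    # Divide each row (..., seq_len_q, :) by its sum; the einsum assumes rank 4.
    inv_sum = 1 / tf.math.reduce_sum(attention_weights, axis=-1)
    attention_weights = tf.einsum('ijlm,ijl->ijlm', attention_weights, inv_sum)

  return attention_weights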
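For comparison, here is a sketch of a different re-normalization (an alternative, not taken from the paper or the repository). Multiplying the softmax weights by exp(penalty) and re-normalizing each row gives exactly the same result as adding the penalty to the logits before the softmax, because softmax(x + p)_j = exp(x_j) exp(p_j) / sum_k exp(x_k) exp(p_k). Unlike a global shift, it also leaves zero-weight (masked) positions at zero. The helper name and the linear 0.5 penalty in the toy check are illustrative assumptions.

```python
import tensorflow as tf


def penalize_weights_multiplicatively(attention_weights, attention_penalty):
    """Re-weight softmax outputs by exp(penalty) and re-normalize each row.

    Hypothetical helper. attention_penalty is assumed to be <= 0 and to
    broadcast against attention_weights (..., seq_len_q, seq_len_k).
    Entries whose weight is exactly 0 (e.g. masked keys) stay at 0.
    """
    reweighted = attention_weights * tf.exp(attention_penalty)
    return reweighted / tf.reduce_sum(reweighted, axis=-1, keepdims=True)


# Toy check: the result matches adding the penalty to the logits before softmax.
logits = tf.random.normal([2, 4, 5, 5], seed=0)
positions = tf.range(5)
distance = tf.cast(tf.abs(positions[:, None] - positions[None, :]), tf.float32)
penalty = -0.5 * distance  # illustrative linear penalty

via_logits = tf.nn.softmax(logits + penalty, axis=-1)
via_weights = penalize_weights_multiplicatively(
    tf.nn.softmax(logits, axis=-1), penalty)
max_abs_diff = float(tf.reduce_max(tf.abs(via_logits - via_weights)))
```

Because the two forms are equivalent, the penalty can simply be added to scaled_attention_logits next to the mask, which avoids a second normalization pass.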

The code snippet below creates the attention penalty matrix. I could not find an efficient way to create a penalty matrix for the second multi-head attention in the decoder (the encoder-decoder attention), because those attention maps are not square (target length × input length), so the alignment does not follow the main diagonal. Therefore, as a first step, I am applying the attention penalty only to the encoder. The code assigns linearly larger penalties to elements farther from the diagonal.
There are two hyperparameters: attention_penalty_scale (similar to the penalty_values that Jindřich suggested) and the width of the diagonal band (attention_penalty_width).
I could also add an option such as stripe_step_size; the current implementation corresponds to stripe_step_size = 1.

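import tensorflow as tf

def create_attention_penalty(inp_len, tar_len, num_heads, attention_penalty_width):
  # tar_len is not used yet: only the (square) encoder penalty is built.
  max_inp_len = tf.cast(tf.math.reduce_max(inp_len), tf.int32)
  n_batch = tf.shape(inp_len)[0]

  enc_att_penalty = tf.ones([n_batch, num_heads, max_inp_len, max_inp_len])

  # Each band_part(ones, i, i) - 1 is 0 inside the band |row - col| <= i and
  # -1 outside it, so summing over i makes the penalty grow linearly with
  # the distance from the diagonal beyond attention_penalty_width - 1.
  accum = tf.zeros([n_batch, num_heads, max_inp_len, max_inp_len])
  for i in range(attention_penalty_width - 1, max_inp_len - 1):
    accum += tf.linalg.band_part(enc_att_penalty, i, i) - 1

  enc_att_penalty = accum

  # None: no penalty is built for the decoder attention yet.
  return enc_att_penalty, None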
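The loop above has a closed form. For an entry whose positions are d = |i - j| apart, band_part(ones, i, i) - 1 contributes -1 exactly when d > i, so summing over i from attention_penalty_width - 1 to max_inp_len - 2 gives -max(0, d - (attention_penalty_width - 1)), assuming attention_penalty_width >= 1. That value can be built directly from a distance matrix with a handful of ops instead of one band_part call per diagonal. A minimal sketch that keeps the original signature; the function name is hypothetical:

```python
import tensorflow as tf


def create_attention_penalty_vectorized(inp_len, tar_len, num_heads,
                                        attention_penalty_width):
    """Loop-free equivalent of create_attention_penalty (hypothetical name).

    Assumes attention_penalty_width >= 1. Entry (i, j) of every
    (batch, head) slice is -max(0, |i - j| - (attention_penalty_width - 1)).
    tar_len is unused, as in the original.
    """
    max_inp_len = tf.cast(tf.math.reduce_max(inp_len), tf.int32)
    n_batch = tf.shape(inp_len)[0]

    # |i - j| for every query/key pair: (max_inp_len, max_inp_len)
    positions = tf.range(max_inp_len)
    distance = tf.abs(positions[:, tf.newaxis] - positions[tf.newaxis, :])

    # 0 inside the band, then -1, -2, ... per step away from it
    penalty = -tf.maximum(distance - (attention_penalty_width - 1), 0)
    penalty = tf.cast(penalty, tf.float32)

    # Tile to (batch, heads, len, len) to match the original output; a
    # (1, 1, len, len) tensor would broadcast just as well.
    enc_att_penalty = tf.tile(penalty[tf.newaxis, tf.newaxis],
                              [n_batch, num_heads, 1, 1])
    return enc_att_penalty, None
```

Returning the un-tiled (1, 1, len, len) tensor instead would also avoid materializing one copy per batch element and head.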

Even though I implemented it as I understood it, I did not get any accuracy improvement. This implementation also has another downside: training became slower.

Q) How can this attention penalty be applied efficiently to both square and non-square attention weights?
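One possible way to handle the non-square case, sketched under an explicit assumption that does not come from the paper: treat the alignment between target step i and input frame j as roughly linear, so that for example b the "diagonal" runs along j = i * inp_len[b] / tar_len[b], and penalize the distance from that line in the same way as in the square case. With inp_len == tar_len it reduces to the square penalty. The function name is hypothetical.

```python
import tensorflow as tf


def create_cross_attention_penalty(inp_len, tar_len, attention_penalty_width):
    """Distance-from-diagonal penalty for non-square (query x key) maps.

    Hypothetical helper. Assumes the alignment between target step i and
    input frame j is roughly linear, so the "diagonal" for example b is
    j = i * inp_len[b] / tar_len[b]. Returns shape
    (batch, 1, max_tar_len, max_inp_len), which broadcasts over heads.
    """
    inp_len = tf.cast(inp_len, tf.float32)
    tar_len = tf.cast(tar_len, tf.float32)
    max_inp_len = tf.cast(tf.reduce_max(inp_len), tf.int32)
    max_tar_len = tf.cast(tf.reduce_max(tar_len), tf.int32)

    q_pos = tf.cast(tf.range(max_tar_len), tf.float32)  # (T,)
    k_pos = tf.cast(tf.range(max_inp_len), tf.float32)  # (S,)
    ratio = inp_len / tar_len  # (B,) input frames per target step

    # Expected key position for each query step: (B, T, 1)
    expected = q_pos[tf.newaxis, :, tf.newaxis] * ratio[:, tf.newaxis, tf.newaxis]
    # Distance of every key position from that expected position: (B, T, S)
    distance = tf.abs(k_pos[tf.newaxis, tf.newaxis, :] - expected)

    # 0 inside the band, growing linearly outside it
    penalty = -tf.maximum(0.0, distance - (attention_penalty_width - 1))
    return penalty[:, tf.newaxis]  # add a heads axis: (B, 1, T, S)
```

Two caveats: padded rows and columns also receive penalties, which is harmless only if they are masked anyway, and at decoding time the final target length is not known in advance, so tar_len would have to be estimated.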

Reference
[1] Linhao Dong, Shuang Xu, Bo Xu, Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition, ICASSP 2018, https://ieeexplore.ieee.org/document/8462506


Solution

  • I think you understand it correctly. They probably used a stripe around the diagonal, something like:

    attention_penalty = (1 - tf.linalg.band_part(tf.ones_like(scaled_attention_logits), stripe_size, stripe_size)) * penalty
    

    However, you will probably need to experiment with what stripe_size and penalty_values should be, because the paper does not say much about them. Alternatively, you could try writing to the authors.
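Spelled out as a self-contained sketch (a reading of the answer, not code from the paper): the stripe is built from tf.ones_like(...) so that band_part yields a 0/1 mask, and the resulting constant penalty is assumed to be subtracted from the scaled logits before masking and softmax. The function name is hypothetical, and penalty is assumed to be a positive scalar.

```python
import tensorflow as tf


def add_stripe_penalty(scaled_attention_logits, stripe_size, penalty):
    """Subtract a constant penalty from logits outside a diagonal stripe.

    Entries with |i - j| <= stripe_size are untouched; all others get
    `penalty` subtracted. band_part also accepts non-square maps, but the
    stripe then still follows the main diagonal (i == j).
    """
    inside_stripe = tf.linalg.band_part(
        tf.ones_like(scaled_attention_logits), stripe_size, stripe_size)
    attention_penalty = (1.0 - inside_stripe) * penalty
    return scaled_attention_logits - attention_penalty
```

Unlike the linear penalty in the question, this applies the same penalty to every position outside the stripe, so stripe_size and penalty are the only two values to tune.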