
Why Is Scalar Multiply Before Einsum Faster?


In the TensorFlow Keras implementation of Multi-Head Attention, instead of evaluating the numerator QKᵀ first, as in

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

they evaluate Q/√dₖ first and add the comment:

Note: Applying scalar multiply at the smaller end of einsum improves XLA performance, but may introduce slight numeric differences in the Transformer attention head.

How is it faster this way? Wouldn't dividing after the einsum be just as fast?
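A minimal pure-Python sketch of the two orderings for a single batch element and head. It assumes Q is T × d_k and K is S × d_k; the sizes and helper names are hypothetical and this is not the Keras code. Both functions compute the same scores, and only the number of scaling multiplies and the floating-point rounding differ.

```python
import math
import random

def matmul_transpose(a, b):
    """Hypothetical helper: return a @ b.T for row-major lists of lists."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]

def scores_scale_after(q, k, d_k):
    """(Q K^T) / sqrt(d_k): one scaling multiply per element of the T x S output."""
    scale = 1.0 / math.sqrt(d_k)
    prod = matmul_transpose(q, k)
    return [[v * scale for v in row] for row in prod]

def scores_scale_before(q, k, d_k):
    """(Q / sqrt(d_k)) K^T: one scaling multiply per element of the T x d_k query."""
    scale = 1.0 / math.sqrt(d_k)
    q_scaled = [[v * scale for v in row] for row in q]
    return matmul_transpose(q_scaled, k)

random.seed(0)
T, S, d_k = 4, 16, 8  # hypothetical sizes with S > d_k
q = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(T)]
k = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(S)]

after = scores_scale_after(q, k, d_k)
before = scores_scale_before(q, k, d_k)

# Same math, but rounding can differ slightly between the two orderings.
max_diff = max(abs(a - b) for ra, rb in zip(after, before) for a, b in zip(ra, rb))
print("scaling multiplies, after :", T * S)    # 64
print("scaling multiplies, before:", T * d_k)  # 32
print("max abs difference:", max_diff)
```

With S > d_k, scaling the query touches fewer elements. The maximum difference is at the level of rounding error, which matches the "slight numeric differences" mentioned in the comment.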


Solution

  • The comment suggests that the tensor being scaled, query, has fewer elements than the attention_scores tensor produced by the einsum in the following line.

    attention_scores = tf.einsum(self._dot_product_equation, key, query)
    

    Given the dimensions

                query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
                key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
    

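    Assuming that _dot_product_equation simply performs a batched matrix multiplication, then for each batch element and head Q is T × key_dim and K is S × key_dim, so the product QKᵀ is T × S. Scaling Q first costs T · key_dim multiplications, while scaling the product costs T · S, so if S > key_dim, scaling before the einsum needs fewer multiplications.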

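    Either way, the scaling should not be the dominant cost: the einsum itself performs T · S · key_dim multiply-adds per batch element and head, which is key_dim times the cost of scaling its output and S times the cost of scaling the query. The difference should therefore be small unless XLA handles the extra elementwise multiply poorly.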
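The cost comparison can be made concrete with the shapes listed above (N heads, key_dim per head). This is a back-of-the-envelope sketch: the helper is hypothetical, it assumes the einsum produces attention_scores of shape (B, N, T, S), and it counts scalar multiplications only, ignoring memory traffic and fusion.

```python
def scaling_vs_einsum_cost(B, T, S, N, key_dim):
    """Hypothetical helper: count scalar multiplications.

    Assumed shapes: query (B, T, N, key_dim), key (B, S, N, key_dim),
    attention_scores (B, N, T, S).
    """
    scale_query = B * T * N * key_dim      # scale query before the einsum
    scale_scores = B * N * T * S           # scale attention_scores after the einsum
    einsum = B * N * T * S * key_dim       # multiply-adds inside the einsum itself
    return scale_query, scale_scores, einsum

# Hypothetical sizes: batch 8, 512 query/key positions, 8 heads, 64 dims per head.
before, after, einsum = scaling_vs_einsum_cost(B=8, T=512, S=512, N=8, key_dim=64)
print(before, after, einsum)
print(f"scaling costs {before / einsum:.4%} (before) or {after / einsum:.4%} (after) of the einsum")
```

Algebraically, the two ratios reduce to 1/S and 1/key_dim. This is why neither placement of the scaling should dominate the einsum's own work.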