
Matrix multiplication with transpose with Tensorflow


I'm trying to do matrix multiplication in TF:

import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
tf.matmul(a1,a1,transpose_b=True)

This works perfectly fine, but if I transpose the input a1 manually like the following, I get an error:

tf.matmul(a1,tf.transpose(a1))

Error:

InvalidArgumentError: In[0] and In[1] must have compatible batch dimensions: [5,4,64] vs. [64,4,5] [Op:BatchMatMulV2]
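The [64,4,5] shape in the error comes from the default behaviour of tf.transpose: when no perm argument is given, it reverses the order of every axis, so the batch axis (5) moves to the end. A minimal sketch, assuming TensorFlow 2.x with eager execution, that reproduces that shape:

```python
import tensorflow as tf

a1 = tf.random.normal(shape=(5, 4, 64))

# With no `perm`, tf.transpose reverses the order of ALL axes,
# so (5, 4, 64) becomes (64, 4, 5) and the batch axis ends up last.
reversed_all = tf.transpose(a1)
print(reversed_all.shape.as_list())  # [64, 4, 5]

# This is equivalent to spelling out perm=[2, 1, 0] explicitly.
explicit = tf.transpose(a1, perm=[2, 1, 0])
print(bool(tf.reduce_all(tf.equal(reversed_all, explicit))))  # True
```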

Documentation:

transpose_b: If True, b is transposed before multiplication.
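For inputs of rank 3 or more, "transposed" here means only the two innermost (matrix) axes of b are swapped; the leading batch axis is left alone. A quick check, assuming TensorFlow 2.x eager execution, comparing transpose_b=True with tf.linalg.matrix_transpose (which swaps just the last two axes):

```python
import tensorflow as tf

a1 = tf.random.normal(shape=(5, 4, 64))

# transpose_b=True transposes only the last two axes of b.
via_flag = tf.matmul(a1, a1, transpose_b=True)

# The same product done by hand: swap the last two axes of each matrix
# in the batch, keeping the batch axis first.
via_matrix_transpose = tf.matmul(a1, tf.linalg.matrix_transpose(a1))

print(via_flag.shape.as_list())  # [5, 4, 4]

# Compare with a tolerance, since different kernels may round differently.
max_diff = tf.reduce_max(tf.abs(via_flag - via_matrix_transpose))
print(float(max_diff))
```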

So I don't understand the difference. Any suggestions would be helpful.


Solution

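  • Called without a perm argument, tf.transpose reverses all axes, so your (5, 4, 64) tensor becomes (64, 4, 5) and the batch axis moves to the end. transpose_b=True, by contrast, transposes only the last two (matrix) dimensions of b. For batched 3-D matrix multiplication, the batch axis must stay first and match in both inputs, and the inner dimensions must agree: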

    [batch, n, k] @ [batch, k, m] -> [batch, n, m]
    

    So in your case, swap only the last two axes with perm=[0, 2, 1]:

    import tensorflow as tf
    
    a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
    
    a1.shape, tf.transpose(a1, perm=[0, 2, 1]).shape
    # Output: (TensorShape([5, 4, 64]), TensorShape([5, 64, 4]))
    
    # swap the height and width axes, not the batch axis
    tf.matmul(a1, tf.transpose(a1, perm=[0, 2, 1])).shape
    # Output: TensorShape([5, 4, 4])
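The same shape rule applies when the inner dimensions differ, and tf.einsum can express the a1 @ a1^T product without any explicit transpose. A sketch assuming TensorFlow 2.x; the shapes of a and b are made up purely to illustrate the rule:

```python
import tensorflow as tf

# Illustrative shapes: [batch, n, k] @ [batch, k, m] -> [batch, n, m]
a = tf.random.normal(shape=(5, 4, 64))   # batch=5, n=4, k=64
b = tf.random.normal(shape=(5, 64, 3))   # batch=5, k=64, m=3
prod = tf.matmul(a, b)
print(prod.shape.as_list())  # [5, 4, 3]

# The same product with einsum: 'b' is the shared batch axis and
# 'k' is the contracted (summed) axis.
prod_einsum = tf.einsum('bnk,bkm->bnm', a, b)
print(prod_einsum.shape.as_list())  # [5, 4, 3]

# The original a1 @ a1^T case written directly in einsum:
# index j (size 64) is contracted, the batch index b is kept.
a1 = tf.random.normal(shape=(5, 4, 64))
gram = tf.einsum('bij,bkj->bik', a1, a1)
print(gram.shape.as_list())  # [5, 4, 4]
```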