Trying to do matrix multiplication in TF:
import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
tf.matmul(a1,a1,transpose_b=True)
This works perfectly fine, but if I transpose the input a1 manually, as follows, I get an error:
tf.matmul(a1,tf.transpose(a1))
Error:
InvalidArgumentError: In[0] and In[1] must have compatible batch dimensions: [5,4,64] vs. [64,4,5] [Op:BatchMatMulV2]
Documentation:
transpose_b: If True, b is transposed before multiplication.
So I don't understand the difference between the two. Any suggestions would be helpful.
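tf.matmul treats every axis except the last two as a batch axis. For 3-D tensors the batch sizes must match (or be broadcastable) and the inner dimensions must agree:
[batch, n, m] @ [batch, m, p] -> [batch, n, p]
transpose_b=True swaps only the last two axes of b, so b is treated as [5, 64, 4]. tf.transpose(a1) without a perm argument reverses all axes, so [5, 4, 64] becomes [64, 4, 5] and the batch sizes (5 vs. 64) no longer match, which is exactly what the error reports.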
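The shape rule can be checked without TensorFlow at all. Below is a minimal pure-Python sketch; the helper names (reversed_perm_shape, swap_last_two, batched_matmul_shape) are hypothetical, and it assumes tf.matmul's batch check behaves like NumPy-style broadcasting of the outer axes, which is a simplification rather than TF's actual implementation.

```python
def reversed_perm_shape(shape):
    """Shape after tf.transpose(x) with no perm: every axis is reversed."""
    return tuple(reversed(shape))


def swap_last_two(shape):
    """Shape after swapping only the last two axes (what transpose_b=True does to b)."""
    shape = tuple(shape)
    return shape[:-2] + (shape[-1], shape[-2])


def batched_matmul_shape(a_shape, b_shape, transpose_b=False):
    """Hypothetical helper: result shape of a batched matmul, or ValueError."""
    if transpose_b:
        b_shape = swap_last_two(b_shape)
    *a_batch, n, m = a_shape
    *b_batch, m2, p = b_shape
    # Outer (batch) axes, aligned from the right: equal, or one of them is 1.
    batch = []
    for i in range(1, max(len(a_batch), len(b_batch)) + 1):
        da = a_batch[-i] if i <= len(a_batch) else 1
        db = b_batch[-i] if i <= len(b_batch) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(
                f"incompatible batch dimensions: {tuple(a_shape)} vs {tuple(b_shape)}"
            )
        batch.append(max(da, db))
    # Inner axes: columns of a must equal rows of b.
    if m != m2:
        raise ValueError(f"inner dimensions differ: {m} vs {m2}")
    return tuple(reversed(batch)) + (n, p)


print(reversed_perm_shape((5, 4, 64)))  # (64, 4, 5)
print(batched_matmul_shape((5, 4, 64), (5, 4, 64), transpose_b=True))  # (5, 4, 4)
try:
    batched_matmul_shape((5, 4, 64), (64, 4, 5))
except ValueError as e:
    print(e)  # incompatible batch dimensions: (5, 4, 64) vs (64, 4, 5)
```

The failing call mirrors tf.matmul(a1, tf.transpose(a1)): the batch check fails before the inner dimensions are even compared.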
So, in your case, swap only the last two axes and leave the batch axis in place:
import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
a1.shape, tf.transpose(a1, perm=[0, 2, 1]).shape
# (TensorShape([5, 4, 64]), TensorShape([5, 64, 4]))
# swap the height and width axes, not the batch axis
tf.matmul(a1, tf.transpose(a1, perm=[0, 2, 1])).shape
# TensorShape([5, 4, 4])
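As a sanity check, the sketch below compares transpose_b=True with the perm=[0, 2, 1] version and with two alternatives that are not part of the original answer: tf.linalg.matrix_transpose, which swaps only the last two axes for any rank >= 2, and tf.einsum with the contraction written out. The seed and the tolerances are arbitrary choices, since different kernels may round slightly differently in float32.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)  # arbitrary seed, only for reproducibility
a1 = tf.random.normal(shape=(5, 4, 64))

# 1. Let matmul transpose the last two axes of b.
r1 = tf.matmul(a1, a1, transpose_b=True)
# 2. Transpose explicitly, keeping the batch axis (0) in place.
r2 = tf.matmul(a1, tf.transpose(a1, perm=[0, 2, 1]))
# 3. tf.linalg.matrix_transpose swaps only the last two axes.
r3 = tf.matmul(a1, tf.linalg.matrix_transpose(a1))
# 4. einsum: out[b, i, k] = sum_j a1[b, i, j] * a1[b, k, j]
r4 = tf.einsum("bij,bkj->bik", a1, a1)

for r in (r2, r3, r4):
    # Compare with a tolerance because float32 rounding can differ per kernel.
    assert np.allclose(r1.numpy(), r.numpy(), rtol=1e-4, atol=1e-4)

print(tuple(r1.shape))  # (5, 4, 4)
```

tf.linalg.matrix_transpose is convenient when the rank is not fixed, because it needs no explicit perm list.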