
What is the difference between the matrix multiplication methods and functions in TensorFlow?


What are the differences between these three ways to multiply two matrices in TensorFlow? The three ways are:

  1. @
  2. tf.tensordot()
  3. tf.matmul()

I have tested them and they give the same result, but I want to know whether there is any underlying difference.


Solution

  • Let us understand this with the example below. I have taken two matrices, a and b, to apply these functions to:

    import tensorflow as tf
    
    a = tf.constant([[1, 2],
                     [3, 4]])
    b = tf.constant([[1, 1],
                     [1, 1]]) # or `tf.ones([2,2])`
    

    tf.matmul(a, b) and (a @ b) both perform matrix multiplication:

    print(tf.matmul(a, b), "\n")   # matrix multiplication
    

    Output:

    tf.Tensor(
    [[3 3]
     [7 7]], shape=(2, 2), dtype=int32) 
    

    The @ operator gives the same output for the same matrices:

    print(a @ b, "\n") # @ used as the matrix multiplication operator
    

    Output:

    tf.Tensor(
    [[3 3]
     [7 7]], shape=(2, 2), dtype=int32) 
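
One practical difference between the two spellings is that tf.matmul is a function, so it accepts keyword options such as transpose_b, which the @ operator has no way to pass. A minimal sketch, reusing a from above together with b2, a non-symmetric matrix introduced here purely for illustration:

```python
import tensorflow as tf

a = tf.constant([[1, 2],
                 [3, 4]])
# b2 is a hypothetical second matrix (not from the original example),
# chosen non-symmetric so that transposing it changes the product.
b2 = tf.constant([[1, 0],
                  [2, 1]])

# The function form can transpose an operand as part of the call...
via_function = tf.matmul(a, b2, transpose_b=True)
# ...while the operator form needs the transpose written out explicitly.
via_operator = a @ tf.transpose(b2)

print(via_function.numpy())
print(bool(tf.reduce_all(tf.equal(via_function, via_operator))))
```

Both lines compute a times the transpose of b2, so the choice between them is mostly about readability.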
    

    tf.tensordot() - Tensordot (also known as tensor contraction) sums the product of elements from a and b over the indices specified by axes.

    If we take axes=0 (a scalar 0, so no axes are contracted):

    print(tf.tensordot(a, b, axes=0), "\n")
    # Each element (scalar) of the first matrix is multiplied by every element
    # of the second matrix, and each of those products is kept as its own 2x2 block.
    

    Output:

    tf.Tensor(
    [[[[1 1]
       [1 1]]
    
      [[2 2]
       [2 2]]]
    
    
     [[[3 3]
       [3 3]]
    
      [[4 4]
       [4 4]]]], shape=(2, 2, 2, 2), dtype=int32) 
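
The axes=0 case above is an outer product: nothing is summed, so entry [i, j, k, l] of the result is a[i, j] * b[k, l]. One way to check that reading is to rebuild the same tensor with plain broadcasting. This is only a sketch; the variable names outer and broadcast are introduced here for illustration:

```python
import tensorflow as tf

a = tf.constant([[1, 2],
                 [3, 4]])
b = tf.constant([[1, 1],
                 [1, 1]])

outer = tf.tensordot(a, b, axes=0)

# Give a two trailing axes of size 1, so that each a[i, j] is broadcast
# against the whole of b: result[i, j, k, l] = a[i, j] * b[k, l].
broadcast = a[:, :, tf.newaxis, tf.newaxis] * b

print(outer.shape)
print(bool(tf.reduce_all(tf.equal(outer, broadcast))))
```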
    

    If we change to axes=1:

    print(tf.tensordot(a, b, axes=1), "\n")
    # performs matrix multiplication
    

    Output:

    tf.Tensor(
    [[3 3]
     [7 7]], shape=(2, 2), dtype=int32) 
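
With axes=1, tensordot contracts the last axis of a with the first axis of b, which for two rank-2 matrices is exactly matrix multiplication. On higher-rank inputs the two functions diverge: tf.matmul treats the leading axes as batch dimensions, while tf.tensordot keeps every axis that is not contracted. A sketch comparing the shapes, where x and y are hypothetical rank-3 tensors that do not appear in the original example:

```python
import tensorflow as tf

a = tf.constant([[1, 2],
                 [3, 4]])
b = tf.constant([[1, 1],
                 [1, 1]])

# axes=1 is shorthand for "last axis of a against first axis of b".
explicit = tf.tensordot(a, b, axes=[[1], [0]])
print(bool(tf.reduce_all(tf.equal(explicit, tf.matmul(a, b)))))

# x and y are hypothetical rank-3 inputs, used only to compare shapes.
x = tf.ones([2, 3, 2])
y = tf.ones([2, 2, 4])

# matmul: the leading axis is a batch of 2, each item a (3, 2) @ (2, 4) product.
batched = tf.matmul(x, y)
# tensordot: contracts x's last axis with y's first axis and keeps the rest.
contracted = tf.tensordot(x, y, axes=1)

print(batched.shape)     # (2, 3, 4)
print(contracted.shape)  # (2, 3, 2, 4)
```

So the three forms agree on plain 2-D matrices, but only tf.matmul and @ perform batched matrix multiplication.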
    

    And for axes=2:

    print(tf.tensordot(a, b, axes=2), "\n")
    # performs element-wise multiplication, then sums the result into a scalar.
    

    Output:

    tf.Tensor(10, shape=(), dtype=int32) 
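
Because b is all ones, the result above is simply the sum of a's elements, which hides how the elements are paired. The sketch below repeats the axes=2 case with c, a non-uniform matrix introduced here for illustration, and compares it with an explicit element-wise product followed by a sum:

```python
import tensorflow as tf

a = tf.constant([[1, 2],
                 [3, 4]])
# c is a hypothetical matrix (not from the original example) with distinct
# entries, so the pairing a[i, j] * c[i, j] is visible in the result.
c = tf.constant([[5, 6],
                 [7, 8]])

full_contraction = tf.tensordot(a, c, axes=2)
elementwise_sum = tf.reduce_sum(a * c)

# 1*5 + 2*6 + 3*7 + 4*8 = 70 in both cases.
print(full_contraction.numpy(), elementwise_sum.numpy())
```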
    

    You can explore more about tf.tensordot() and the basics of the axes argument in the given links.