Tags: python, tensorflow, matrix-multiplication, tensor

How to use matrix multiplication to turn a (Batch, a, b) tensor times a (Batch, b) tensor into a (Batch, a) tensor in TensorFlow 1.10


For example,

# Batch = 2, a = 25, b = 2
# tensor t1 shape: (Batch, a, b)
# tensor t2 shape: (Batch, b)
# tensor res shape: (Batch, a)
print(t1)
<tf.Tensor: id=466, shape=(2, 25, 2), dtype=int32, numpy=
array([[[ 1, 26],
        [ 2, 27],
        [ 3, 28],
        [ 4, 29],
        [ 5, 30],
        [ 6, 31],
        [ 7, 32],
        [ 8, 33],
        [ 9, 34],
        [10, 35],
        [11, 36],
        [12, 37],
        [13, 38],
        [14, 39],
        [15, 40],
        [16, 41],
        [17, 42],
        [18, 43],
        [19, 44],
        [20, 45],
        [21, 46],
        [22, 47],
        [23, 48],
        [24, 49],
        [25, 50]],

       [[ 1, 26],
        [ 2, 27],
        [ 3, 28],
        [ 4, 29],
        [ 5, 30],
        [ 6, 31],
        [ 7, 32],
        [ 8, 33],
        [ 9, 34],
        [10, 35],
        [11, 36],
        [12, 37],
        [13, 38],
        [14, 39],
        [15, 40],
        [16, 41],
        [17, 42],
        [18, 43],
        [19, 44],
        [20, 45],
        [21, 46],
        [22, 47],
        [23, 48],
        [24, 49],
        [25, 50]]], dtype=int32)>

print(t2)
<tf.Tensor: id=410, shape=(2, 2), dtype=int32, numpy=
array([[1, 0],
       [1, 0]], dtype=int32)>

# after matrix multiplication
print(res)
<tf.Tensor: id=474, shape=(2, 25), dtype=int32, numpy=
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25]], dtype=int32)>

My idea is to use matrix multiplication so that only part of t1 is kept, as in the example above, but I am struggling to implement it.
Could anyone help?
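To make the "keep only part" idea concrete, here is a minimal sketch that rebuilds the printed example with Batch = 2. It uses NumPy rather than TensorFlow purely so the arithmetic is easy to check (an assumption of this sketch, not part of the question). A one-hot row in t2 such as [1, 0] multiplies column 0 by 1 and column 1 by 0, so each (a, b) matrix times its length-b vector returns column 0, which is exactly the requested res.

```python
import numpy as np

# Rebuild the example: Batch = 2, a = 25, b = 2.
# Column 0 holds 1..25 and column 1 holds 26..50, identical for both batch entries.
batch, a, b = 2, 25, 2
one_matrix = np.stack([np.arange(1, a + 1), np.arange(a + 1, 2 * a + 1)], axis=1)  # (a, b)
t1 = np.stack([one_matrix] * batch)  # (batch, a, b)
t2 = np.array([[1, 0], [1, 0]])      # (batch, b); one-hot rows select column 0

# For each batch entry, multiply the (a, b) matrix by its length-b vector.
res = np.stack([t1[k] @ t2[k] for k in range(batch)])  # (batch, a)

print(res.shape)
print(res[0, :5])
```

The per-batch loop is only there to spell out the semantics; the answer below shows how to do the same thing in one vectorized call.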


Solution

  • Starting with

    import tensorflow as tf # 2.6.0
    Batch, a, b = 5, 25, 2
    t1 = tf.random.normal((Batch, a, b))
    t2 = tf.random.normal((Batch, b))
    

    If I understand correctly, you want t3[b,i] = sum over j of t1[b,i,j] * t2[b,j]. This can be expressed directly with Einstein summation:

    t3 = tf.einsum('bij,bj->bi', t1, t2)
    assert t3.shape == (Batch, a)
    

    This can also be written with broadcasting and reduce_sum:

    t3 = tf.reduce_sum(t1 * t2[:,None,:], axis=2)
    assert t3.shape == (Batch, a)
    

    Or, if you want to use the matrix multiplication operator, that works too, but you have to unsqueeze t2 to (Batch, b, 1) (t2[:,:,None]) and squeeze the result back from (Batch, a, 1) to (Batch, a) (t3[:,:,0]).

    t3 = t1 @ t2[:,:,None]
    assert t3[:,:,0].shape == (Batch, a)
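As a sanity check that the three formulations compute the same thing, here is a sketch that compares each against an explicit triple loop spelling out t3[b,i] = sum over j of t1[b,i,j] * t2[b,j]. It assumes NumPy as a stand-in for TensorFlow: np.einsum, broadcasting with None, and the @ operator follow the same shape rules as tf.einsum, tf.reduce_sum with broadcasting, and batched @ for these inputs.

```python
import numpy as np

# Same sizes as the answer; NumPy stands in for TF here (assumption).
batch, a, b = 5, 25, 2
rng = np.random.default_rng(0)
t1 = rng.normal(size=(batch, a, b))
t2 = rng.normal(size=(batch, b))

# Reference: t3[k, i] = sum_j t1[k, i, j] * t2[k, j], written as explicit loops.
ref = np.zeros((batch, a))
for k in range(batch):
    for i in range(a):
        for j in range(b):
            ref[k, i] += t1[k, i, j] * t2[k, j]

via_einsum = np.einsum('bij,bj->bi', t1, t2)    # mirrors tf.einsum
via_reduce = (t1 * t2[:, None, :]).sum(axis=2)  # mirrors tf.reduce_sum(..., axis=2)
via_matmul = (t1 @ t2[:, :, None])[:, :, 0]     # (batch, a, 1) squeezed to (batch, a)

for name, result in [("einsum", via_einsum), ("reduce_sum", via_reduce), ("matmul", via_matmul)]:
    print(name, result.shape, np.allclose(result, ref))
```

All three should agree with the loop up to floating-point rounding, which is why np.allclose is used instead of exact equality.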