python, tensorflow, matrix-multiplication

How to do matrix-scalar multiplication in TensorFlow with batch?


First, here is the formula I am trying to compute:

Math formula (I can't post the image directly), written out to match the code below:

$$L_i = I + \sum_{c=1}^{C} N_{i,c}\, T_c^{\top} T_c$$

where I is the identity matrix of shape (M, M), N_i is a vector of length C, T is a matrix of shape (C*F, M), and the T_c are its submatrices of shape (F, M).

My TensorFlow code to compute this looks like this:

import numpy as np
import tensorflow as tf

N_p = tf.placeholder(floatX, shape=[C], name='N_p')
I = tf.Variable(np.eye(M), dtype=tf.float32, name="I")
T = tf.Variable(np.random.rand(C*F, M), dtype=tf.float32, name="T")

# L = I + sum_i N_p[i] * T_c^T T_c, where T_c is the i-th (F, M) block of T
L = I
for i in range(C):
    T_c = T[i*F:(i+1)*F, :]
    L = tf.add(L, tf.scalar_mul(N_p[i], tf.matmul(tf.transpose(T_c), T_c)))

This works fine. However, I need to extend it to batch processing, where N_p becomes:

N_p = tf.placeholder(floatX, shape=[None,C], name='N_p')

Unfortunately, I don't know how to change my TensorFlow code accordingly. The problem is in scalar_mul:

L=tf.add(L,tf.scalar_mul(N_p[:,i],tf.matmul(tf.transpose(T_c),T_c)))

It is clear why this fails (N_p[:,i] is now a vector of length batch, not a scalar), but how should I rewrite it? Thanks a lot for any advice.


Solution

  • You can achieve the above in matrix form, without any loops:

    T   --> reshape to [C, F, M]
    T_1 --> transpose T to --> [C, M, F]
    T_2 --> T as is --> [C, F, M]
    d   --> matmul (T_1, T_2) --> [C, M, M] --> transpose --> [M, M, C]
    out --> multiply (d, N) : d -> [1, M, M, C], N -> [batch, 1, 1, C]
          --> [batch, M, M, C] --> reduce_sum (axis=3) --> [batch, M, M]
          --> add I
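
A small NumPy sketch of the same pipeline, with the shape after each step noted in the comments and a check against a per-sample loop. The sizes `C`, `F`, `M`, `batch` and the random inputs are made-up values for illustration; they are not from the original post.

```python
import numpy as np

# Made-up sizes for illustration only.
C, F, M, batch = 3, 4, 5, 2

rng = np.random.default_rng(0)
T = rng.random((C * F, M))   # stacked blocks T_c, each of shape (F, M)
N = rng.random((batch, C))   # one weight vector per batch element

T3 = T.reshape(C, F, M)                        # (C, F, M): T3[c] is T_c
gram = np.matmul(T3.transpose(0, 2, 1), T3)    # (C, M, M): gram[c] = T_c^T T_c
# Each gram[c] is symmetric, so swapping the two M axes here is harmless.
d = gram.transpose(2, 1, 0)                    # (M, M, C)
weighted = d[None, ...] * N[:, None, None, :]  # (batch, M, M, C) via broadcasting
L_batched = weighted.sum(axis=3) + np.eye(M)   # (batch, M, M)

# Reference: the loop over blocks from the question, repeated for each row of N.
L_loop = np.stack([
    np.eye(M) + sum(N[b, c] * T3[c].T @ T3[c] for c in range(C))
    for b in range(batch)
])

print(np.allclose(L_batched, L_loop))
```

The reduction runs over the last axis because that is where `C` ends up after the transpose.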
    

    The working code (matches your code for batch=1):

    N_1 = tf.placeholder(tf.float32, [None, C])
    reshape_T = tf.reshape(T, [C, F, M])
    
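    # batched matmul over C: (C, M, F) x (C, F, M) -> (C, M, M), i.e. T_c^T T_c for each c
    T_1 = tf.transpose(reshape_T, [0, 2, 1])  # (C, M, F)
    T_2 = reshape_T                           # (C, F, M), already in the right layout

    # move C to the last axis: (C, M, M) -> (M, M, C)
    d = tf.transpose(tf.matmul(T_1, T_2), [2, 1, 0])
    # broadcast (1, M, M, C) * (batch, 1, 1, C), then sum over C and add I
    out = tf.reduce_sum(d[None, ...] * tf.reshape(N_1, [-1, 1, 1, C]), axis=3) + I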
    
    # inp: a NumPy array of shape [batch, C] to feed into N_1
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(out, {N_1: inp}))
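
One possible variation, not part of the original answer: the multiply-then-reduce step materialises an intermediate of shape [batch, M, M, C]. Since the sum over C is just a linear combination of the C matrices T_c^T T_c, it can also be written as a single matrix product of shape (batch, C) x (C, M*M). The sketch below checks this in NumPy with made-up sizes; in TensorFlow the analogous calls would be `tf.reshape` and `tf.matmul`.

```python
import numpy as np

# Made-up sizes for illustration only.
C, F, M, batch = 3, 4, 5, 2

rng = np.random.default_rng(1)
T = rng.random((C * F, M))
N = rng.random((batch, C))

T3 = T.reshape(C, F, M)
gram = np.matmul(T3.transpose(0, 2, 1), T3)  # (C, M, M): gram[c] = T_c^T T_c

# Broadcast-and-reduce, as in the answer: builds a (batch, M, M, C) array.
L_broadcast = (
    gram.transpose(1, 2, 0)[None, ...] * N[:, None, None, :]
).sum(axis=3) + np.eye(M)

# Same weighted sum as one matrix product: (batch, C) @ (C, M*M) -> (batch, M*M).
L_matmul = (N @ gram.reshape(C, M * M)).reshape(batch, M, M) + np.eye(M)

print(np.allclose(L_broadcast, L_matmul))
```

Whether this is faster in practice depends on the sizes involved; the main difference is that the largest intermediate is [batch, M*M] rather than [batch, M, M, C].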