
How to assemble scalars into a matrix in tensorflow?


In short, I want to assemble scalars w_ij into a symmetric matrix W like so:

W[i, j] = w_ij
W[j, i] = w_ij

After struggling with this and looking up material on the internet and here on SE, I can't find a way to construct the matrix W from the w_ij's, and I'm lost as to how to do this. Any help would be appreciated.

An elaboration and an MWE are below.



The problem

In my research I am trying to train a network that maps a source to a scalar w_ij, where the output w_ij is intended to represent element (i, j) of a symmetric matrix W.

So the loss for training is constructed by assembling the outputs of many identical networks (with shared weights, but each seeing a different input and driving a different element of the matrix) into matrix form, like so:

W[i, j] = w_ij
W[j, i] = w_ij

and then training those multiple networks on a loss of the form:

L2_loss(f(W) - f(True_W))

(Here f() is the quadratic form f(Y) = d' Y d, i.e., the matrix multiplied from the left and the right by a fixed vector d.)
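For concreteness, a minimal TF 1.x sketch of f and this loss (the vector d, the matrix true_W, and all shapes here are illustrative assumptions, not part of my actual setup):

import tensorflow as tf

d = tf.constant([[1.0], [2.0]])                 # fixed column vector, shape (2, 1)
W = tf.placeholder(tf.float32, shape=(2, 2))    # stands in for the assembled matrix
true_W = tf.constant([[0.0, 3.0], [3.0, 0.0]])  # stands in for the target matrix

def f(Y):
    # quadratic form d' Y d, returned as a scalar
    return tf.squeeze(tf.matmul(tf.matmul(d, Y, transpose_a=True), d))

loss = tf.nn.l2_loss(f(W) - f(true_W))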

I need gradients to flow through this loss into each network.
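(Concretely, by "gradients flowing through" I mean that with the w_ij's produced by actual networks with trainable weights, rather than the placeholder stand-in above, a standard optimizer step should backpropagate through f(W):)

opt = tf.train.GradientDescentOptimizer(learning_rate=1e-3)
train_op = opt.minimize(loss)  # backprop through f(W) and each W[i, j] into the shared weights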



What I tried

  1. Naive item assignment on a tensor is not supported in TensorFlow, i.e.,

    W[i, j] = w_ij is not supported.

  2. tf.scatter_update() does not allow gradients to flow through it (see the sketch after this list).

  3. Finally, when I thought I was close to a solution, I tried using a tf.Variable for the matrix W, like below:

     W_flat = tf.Variable(initial_value=[0] * (2 * 2), dtype='float32')
    

    and then assigning to this W_flat by slicing it, W_flat[0].assign(w_ij), but it seems that my assignments to this variable do not take effect (see MWE).
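To make items 1 and 2 above concrete, here is a small TF 1.x sketch (the names and shapes are illustrative):

import tensorflow as tf

w_ij = tf.constant(24.0)
W = tf.Variable(tf.zeros((2, 2)))

# 1. naive item assignment on a tensor/variable raises a TypeError:
# W[0, 1] = w_ij

# 2. tf.scatter_update() runs forward, but no gradient is registered for it:
updates = tf.reshape(tf.stack([0.0, w_ij]), (1, 2))
updated = tf.scatter_update(W, [0], updates)
# tf.gradients(updated, w_ij)  # fails: no gradient defined for ScatterUpdate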


MWE

Below is a short MWE, where W is a 2-by-2 symmetric matrix with a zero diagonal, so there is only one independent element that a network has to drive (hence only a single network here), i.e., I would like W to end up with the values

W = [[0,    w_ij],
     [w_ij, 0   ]]

So I try to update:

W_flat[1].assign(w_ij)
W_flat[2].assign(w_ij)

And turn it back into a matrix:

W = tf.reshape(W_flat, (2, 2))

But the update does not go through: the print output shows that W remains all zeros.

The code

import tensorflow as tf

def train():

    with tf.Graph().as_default():
        with tf.device('/cpu'):
            source = tf.placeholder(tf.float32, shape=(2, 3))
            is_training = tf.placeholder(tf.bool, shape=())

            w_ij = tf.reduce_sum(source)

            W_flat = tf.Variable(initial_value=[0] * (2 * 2), dtype='float32')

            # NOTE: each of these only creates an assign op in the graph;
            # nothing ever runs these ops, so W_flat is never actually mutated:
            W_flat[1].assign(w_ij)
            W_flat[2].assign(w_ij)
            tf.assign(W_flat[1], w_ij)
            tf.assign(W_flat[2], w_ij)

            W = tf.reshape(W_flat, (2, 2))

        sess = tf.Session()
        init = tf.global_variables_initializer()
        sess.run(init, {is_training: True})

        ops = {'W_flat': W_flat,
               'source' : source,
               'w_ij' : w_ij,
               'W' : W}

        for epoch in range(2):
            feed_dict = {ops['source']: [[1,1,1], [7,7,7]]}
            res_W_flat, res_wij, res_W = sess.run([ops['W_flat'], ops['w_ij'], ops['W']], feed_dict=feed_dict)
            print("epoch:" ,  epoch)
            print("W_flat:", res_W_flat)
            print("wij:", res_wij)
            print("W:", res_W)

if __name__ == "__main__" :
    train()

The print() output

epoch: 0
W_flat: [0. 0. 0. 0.]
wij: 24.0
W: [[0. 0.]
 [0. 0.]]
epoch: 1
W_flat: [0. 0. 0. 0.]
wij: 24.0
W: [[0. 0.]
 [0. 0.]]

So W and W_flat are not updated with the value of w_ij: w_ij itself evaluates to 24, but W and W_flat remain all zeros.
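In hindsight, the reason the prints stay at zero is that in TF 1.x graph mode assign() only creates an op in the graph; the variable is mutated only when that op is explicitly passed to sess.run(). Continuing the MWE, something like the following makes the assignment take effect (note, however, that a later read of W_flat is still not differentiable back to w_ij, which is why the scatter_nd() approach below is needed):

assign_op = W_flat[1].assign(w_ij)
sess.run(assign_op, feed_dict=feed_dict)
print(sess.run(W_flat))  # expect: [ 0. 24.  0.  0.]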


Solution

  • The solution I found after more struggling is to use tf.scatter_nd() to build the matrix W: unlike tf.scatter_update(), tf.scatter_nd() produces a tensor that does support gradient propagation from its inputs to its output.

    So instead of writing

            W_flat = tf.Variable(initial_value=[0] * (2 * 2), dtype='float32')
    
            W_flat[1].assign(w_ij)
            W_flat[2].assign(w_ij)
    
            W = tf.reshape(W_flat, (2, 2))
    

    It worked using:

            # no tf.Variable is needed; tf.scatter_nd() builds the tensor directly
            indices = tf.constant([[1], [2]])   # flat positions of W[0, 1] and W[1, 0]
            updates = tf.stack([w_ij, w_ij])    # one update per index
            shape = tf.constant([4])
            W_flat = tf.scatter_nd(indices, updates, shape)
    
            W = tf.reshape(W_flat, (2, 2))
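
    As a side note, the intermediate flat vector can be skipped by scattering directly into the 2-by-2 shape with 2-D indices (a sketch under the same setup):

            indices = tf.constant([[0, 1], [1, 0]])   # positions (i, j) and (j, i)
            updates = tf.stack([w_ij, w_ij])          # the same scalar for both entries
            W = tf.scatter_nd(indices, updates, tf.constant([2, 2]))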