Search code examples
Tags: tensorflow, machine-learning, backpropagation, tensorflow-gradient

How to backprop through a model that predicts the weights for another in Tensorflow


I am currently trying to train a model (a hypernetwork) that predicts the weights for another model (the main network) such that the main network's cross-entropy loss decreases. However, when I use tf.assign to assign the new weights to the main network, backpropagation into the hypernetwork is not possible, which renders the system non-differentiable. I have tested whether my weights are properly updated, and they seem to be: subtracting the initial weights from the updated ones gives a non-zero sum.

This is a minimal sample of what I am trying to achieve.

import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import softmax

def random_addition(variables):
    """Build ops that add Gaussian noise to each variable in place.

    For every entry of `variables` a `tf.assign` op is created that writes
    `variable + tf.random_normal(...)` back into that variable, where the
    noise has the same shape as the variable.

    Args:
        variables: iterable of TensorFlow variables to perturb.

    Returns:
        A list containing one assign op per variable, in input order.
    """
    return [
        tf.assign(var, var + tf.random_normal(shape=var.get_shape()))
        for var in variables
    ]


def network_predicted_addition(variables, network_preds):
    """Build assign ops that add predicted offsets to the given variables.

    Only the variable at index 0 receives an update; every other entry of
    `variables` is skipped. The selected variable is printed before its
    assign op is created.

    Args:
        variables: iterable of TensorFlow variables.
        network_preds: sequence of tensors indexed in parallel with
            `variables`; `network_preds[0]` is added to the first variable.

    Returns:
        A list of the created assign ops (at most one element).
    """
    assign_ops = []
    for position, var in enumerate(variables):
        if position != 0:
            continue
        print(var)
        assign_ops.append(tf.assign(var, var + network_preds[position]))
    return assign_ops

def dense_weight_update_net(inputs, reuse):
    """Hypernetwork that maps an input batch to a (3, 3, 3, 16) tensor.

    The input goes through one 3x3 conv layer with 16 filters, is averaged
    over axes 0, 1 and 2, and is then projected by a dense layer to
    16*3*3*3 values that are reshaped to (3, 3, 3, 16).

    Args:
        inputs: input tensor fed to the conv layer.
        reuse: forwarded to `tf.variable_scope("weight_net", ...)`.

    Returns:
        A tensor of shape (3, 3, 3, 16).
    """
    with tf.variable_scope("weight_net", reuse=reuse):
        conv_features = tf.layers.conv2d(
            inputs=inputs,
            kernel_size=(3, 3),
            filters=16,
            strides=(1, 1),
            activation=tf.nn.leaky_relu,
            name="conv_layer_0",
            padding="SAME",
        )
        # Collapse axes 0-2 so a single vector (one value per filter) remains.
        pooled = tf.reduce_mean(conv_features, axis=[0, 1, 2])
        # Re-add a leading axis of size 1 so the dense layer gets a 2-D input.
        pooled_row = tf.reshape(pooled, shape=(1, pooled.get_shape()[0]))
        flat_update = tf.layers.dense(pooled_row, units=(16 * 3 * 3 * 3))
        kernel_update = tf.reshape(flat_update, shape=(3, 3, 3, 16))
    return kernel_update

def conv_net(inputs, reuse):
    """Main network: one conv layer, spatial average pooling, 2-unit dense head.

    Returns unnormalised logits. The result is passed as `logits=` to
    `tf.losses.softmax_cross_entropy`, which applies softmax itself, so
    applying softmax here as well would normalise the scores twice and
    flatten the gradients. Apply `tf.nn.softmax` to the returned tensor if
    class probabilities are needed.

    Args:
        inputs: input tensor fed to the conv layer.
        reuse: forwarded to `tf.variable_scope("conv_net", ...)`; pass True
            to share the variables created by an earlier call.

    Returns:
        A tensor of logits whose last dimension has size 2.
    """
    with tf.variable_scope("conv_net", reuse=reuse):
        output = tf.layers.conv2d(inputs=inputs, kernel_size=(3, 3), filters=16, strides=(1, 1),
                                  activation=tf.nn.leaky_relu, name="conv_layer_0", padding="SAME")
        # Average over axes 1 and 2, keeping the leading axis and the filter axis.
        output = tf.reduce_mean(output, axis=[1, 2])
        output = tf.layers.dense(output, units=2)
    return output

# Toy two-class batch: 32 all-zero inputs labelled 0 and 32 all-one inputs labelled 1.
input_x_0 = tf.zeros(shape=(32, 32, 32, 3))
target_y_0 = tf.zeros(shape=(32), dtype=tf.int32)
input_x_1 = tf.ones(shape=(32, 32, 32, 3))
target_y_1 = tf.ones(shape=(32), dtype=tf.int32)
input_x = tf.concat([input_x_0, input_x_1], axis=0)
target_y = tf.concat([target_y_0, target_y_1], axis=0)

# First forward pass; this call creates the conv_net variables.
output_0 = conv_net(inputs=input_x, reuse=False)

target_y = tf.one_hot(target_y, 2)

crossentropy_loss_0 = tf.losses.softmax_cross_entropy(onehot_labels=target_y, logits=output_0)


conv_net_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="conv_net")
print(conv_net_parameters)
weight_updates = dense_weight_update_net(inputs=input_x, reuse=False)
# Collect the hypernetwork variables only AFTER the hypernetwork has been built.
# Querying the collection before the call above returns an empty list, which would
# leave the optimizer below with no variables to train.
weight_net_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="weight_net")
#updates_0 = random_addition(conv_net_parameters)
updates_1 = network_predicted_addition(conv_net_parameters, network_preds=[weight_updates])
# NOTE: the predicted weights reach conv_net only through tf.assign plus a control
# dependency. As discussed in the surrounding text, that path carries no gradient, so
# crossentropy_loss_1 is not differentiable with respect to the hypernetwork variables.
with tf.control_dependencies(updates_1):
    output_1 = conv_net(inputs=input_x, reuse=True)
    crossentropy_loss_1 = tf.losses.softmax_cross_entropy(onehot_labels=target_y, logits=output_1)
    # Diagnostic: non-zero when the assign changed the network output.
    check_sum = tf.reduce_sum(tf.abs(output_0 - output_1))


c_opt = tf.train.AdamOptimizer(beta1=0.9, learning_rate=0.001)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # Needed for correct batch norm usage
with tf.control_dependencies(update_ops):  # Needed for correct batch norm usage
    train_variables = weight_net_parameters #+ conv_net_parameters

    c_error_opt_op = c_opt.minimize(crossentropy_loss_1,
                                    var_list=train_variables,
                                    colocate_gradients_with_ops=True)


init = tf.global_variables_initializer()

with tf.Session() as sess:
    # Run the initializer without rebinding `init` to the (None) result of sess.run.
    sess.run(init)
    loss_list_0 = []
    loss_list_1 = []
    for i in range(1000):
        _, checksum, crossentropy_0, crossentropy_1 = sess.run([c_error_opt_op, check_sum, crossentropy_loss_0,
                                                                crossentropy_loss_1])
        loss_list_0.append(crossentropy_0)
        loss_list_1.append(crossentropy_1)
        # Running means of both losses since the start of training.
        print(checksum, np.mean(loss_list_0), np.mean(loss_list_1))

Does anyone know how I can get tensorflow to compute the gradients for this? Thank you.


Solution

  • In this case your weights aren't variables, they are computed tensors based on the hypernetwork. All you really have is one network during training. If I understand you correctly you are then proposing to discard the hypernetwork and be able to use just the main network to perform predictions.

    If this is the case, then you can either save the weight values manually and reload them as constants, or keep using tf.assign during training to store them in variables, and use tf.cond to choose between the variable and the computed tensor depending on whether you're doing training or inference.

    During training you will need to use the computed tensor from the hypernetwork in order to enable backprop.


    Example from the comments: w is the weight you'll use. You can assign it to a variable during training to keep track of it, and then use tf.cond to select either the variable (during inference) or the computed value from the hypernetwork (during training). In this example you need to pass in a boolean placeholder, is_training_placeholder, to indicate whether you're running training or inference.

    # Keep a handle on the assign op: in graph mode it only writes w_variable when it is
    # actually run (e.g. fetched together with the train op), so run it during training.
    store_w_op = tf.assign(w_variable, w_from_hypernetwork)
    # Training uses the hypernetwork tensor directly so gradients can flow through it;
    # inference reads the value previously stored in the variable.
    w = tf.cond(is_training_placeholder, true_fn=lambda: w_from_hypernetwork, false_fn=lambda: w_variable)