Tags: python, tensorflow, machine-learning, training-data, pre-trained-model

How to connect the output tensor of a restored graph to the input of the default graph in tensorflow?


I am new to TensorFlow and have been stuck on this for several days. I have the following pretrained model (4 files):

Classification.inception.model-27.data-00000-of-00001
Classification.inception.model-27.index
Classification.inception.model-27.meta
checkpoint 

I can successfully restore this model into the default graph in a new file, test.py:

import tensorflow as tf

with tf.Session() as sess:
    # import_meta_graph rebuilds the saved graph inside the current default graph
    new_restore = tf.train.import_meta_graph('Classification.inception.model-27.meta')
    new_restore.restore(sess, tf.train.latest_checkpoint('/'))
    graph = tf.get_default_graph()
    # tensor names need an output index (':0'); a bare 'input_data' names an operation
    input_data = graph.get_tensor_by_name('input_data:0')
    output = graph.get_tensor_by_name('logits/BiasAdd:0')
    ......
    logits = sess.run(output, feed_dict = {input_data: mybatch})
    ......

The script above works because test.py is independent of train.py, so the restored graph is simply the default graph.

However, I don't know how to integrate this pretrained model into an existing graph, i.e. pass the tensor "output" into a new network defined in Python code (rather than restored from a file), like this:

def main():
    ### load the meta file and restore the pretrained graph here #####
    new_restore = tf.train.import_meta_graph('Classification.inception.model-27.meta')
    new_restore.restore(sess, tf.train.latest_checkpoint('/'))
    graph = tf.get_default_graph()
    input_data = graph.get_tensor_by_name('input_data:0')
    output1 = graph.get_tensor_by_name('logits/BiasAdd:0')
    ......
    with tf.Graph().as_default():
        with tf.variable_scope(scope, 'InceptionResnetV1', [inputs], reuse=reuse):
            with slim.arg_scope([slim.batch_norm, slim.dropout], is_training = is_training):
                with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d]):
                    net = slim.conv2d(output1, 32, 3, stride = 2, scope= 'Conv2d_1a_3x3')

However, passing the tensor output1 into slim.conv2d() raises this error:

ValueError: Tensor("InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/AssignMovingAvg:0", shape=(32,), dtype=float32_ref) is not an element of this graph.

How do people usually handle this (restoring a graph from a .meta file and connecting its output tensor to the input of a network built in the current default graph)?

I searched online and found a similar question ("connect input and output tensors of two different graphs tensorflow"), but I think my problem is still quite different.

There are also similar approaches that restore ".ckpt" files, but I don't think they are what I am looking for either.

Any comments and guidance will be appreciated. Thanks.


Solution

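  • Your issue is that with tf.Graph().as_default(): makes a new, empty graph the default one, so slim.conv2d() builds its layer in a different graph from the one that holds output1. From the documentation: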

    Another typical usage involves the tf.Graph.as_default context manager, which overrides the current default graph for the lifetime of the context.

    Simply remove this line so that the new layers are added to the old graph, like this:

    import tensorflow as tf
    import numpy as np
    
    const_input_dummy = np.random.randn(1, 28)
    
    # create graph and save everything
    x = tf.placeholder(dtype=tf.float32, shape=[1, 28], name='plhdr')
    y = tf.layers.dense(x, 2, name='logits')
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(y, {x: const_input_dummy}))
    
        saver = tf.train.Saver()
        saver.save(sess, './export/inception')
    
    # reset everything so far (like creating another script)
    tf.reset_default_graph()
    
    # answer to question
    with tf.Session() as sess:
        # import old graph structure
        restorer = tf.train.import_meta_graph('./export/inception.meta')
        # get reference to tensors from imported graph
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("plhdr:0")
        y = graph.get_tensor_by_name('logits/BiasAdd:0')
    
        # add some new operations (and variables)
        with tf.variable_scope('new_scope'):
            y = tf.layers.dense(y, 1, name='other_layer')
    
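        # init all variables (including the new, not-yet-saved ones) ...
        sess.run(tf.global_variables_initializer())
        # ... then restore the saved variables; restoring last keeps the
        # initializer from overwriting the pretrained weights
        restorer.restore(sess, tf.train.latest_checkpoint('./export'))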
    
        # this will execute without errors
        print(sess.run(y, {x: const_input_dummy}))
    

    Usually there is no need to maintain multiple graphs, so I suggest working with a single graph.