Tags: python, tensorflow, transfer-learning

Tensorflow session - ValueError: GraphDef cannot be larger than 2GB


I am trying to iterate over dataset batches and run inference on a pre-trained model. I created a session and loaded the model as follows:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

sess = tf.Session()

saver = tf.train.import_meta_graph('model_resnet/imagenet.ckpt.meta')
saver.restore(sess, "model_resnet/imagenet.ckpt")

# To view the graph in TensorBoard:
summary_writer = tf.summary.FileWriter("/tmp/tensorflow_logdir", graph=tf.get_default_graph())

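# To retrieve the outputs of a layer during inference
def getActivations(layer, stimuli):
    # keep_prob: dropout tensor from the restored graph (its definition is not shown)
    units = sess.run(layer, feed_dict={"Placeholder_:0": stimuli, keep_prob: 1.0})
    return units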

# Convert to TF Dataset
dataset_train = tf.data.Dataset.from_tensor_slices((X_train, y_train))
dataset_test = tf.data.Dataset.from_tensor_slices((X_test, y_test))

# Create batches
dataset = dataset_train.batch(32)

# Iterator over the batches
iterator = dataset.make_one_shot_iterator()

next_element = iterator.get_next()

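try:
    while True:  # keep pulling batches until the iterator is exhausted
        # next_element is (images, labels); the model needs the images
        getActivations("resnet/pool:0", sess.run(next_element[0]))
except tf.errors.OutOfRangeError:
    print("End of dataset")  # ==> "End of dataset"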

I get this error:

ValueError: GraphDef cannot be larger than 2GB.
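The 2GB in this message is the size limit on the serialized GraphDef (a protocol buffer); TensorFlow appears to enforce it as 2**31 bytes. If NumPy arrays are stored inside the graph as constants, the graph is at least as large as the raw array bytes, so a back-of-the-envelope estimate shows how quickly image data reaches the limit. The shapes below are hypothetical, since the question does not say how large X_train is, and embedded_bytes is an illustrative helper.

```python
import math

# Assumed threshold, matching the "2GB" in the error message (2**31 bytes).
GRAPHDEF_LIMIT_BYTES = 1 << 31


def embedded_bytes(shape, itemsize):
    """Raw bytes an array of `shape` with `itemsize`-byte elements adds as a constant."""
    return math.prod(shape) * itemsize


# Hypothetical training set: 10,000 RGB images at 224x224, float32 (4 bytes each).
x_bytes = embedded_bytes((10_000, 224, 224, 3), 4)
y_bytes = embedded_bytes((10_000,), 8)  # int64 labels
total = x_bytes + y_bytes

# How many such images fit under the limit, ignoring labels and the model itself.
max_images = GRAPHDEF_LIMIT_BYTES // embedded_bytes((224, 224, 3), 4)

print(f"{total / 2**30:.2f} GiB embedded, over limit: {total >= GRAPHDEF_LIMIT_BYTES}")
print(f"at most {max_images} images of 224x224x3 float32 fit")
```

Under these assumptions a few thousand images are enough to exceed the limit, before the test arrays or any other graph nodes are counted.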

I may be misinterpreting what exactly a graph is. I don't understand why a single iteration over 32 images would make the graph grow. Are my operations being added to the pre-trained model's graph? From what I have read so far, operations are added to a TF graph by calling tf.add or another tf.<function_name>; is this correct?
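One way to check what is actually enlarging a graph is to list its biggest nodes. The sketch below is a minimal example, assuming graph-mode construction inside a fresh tf.Graph and made-up arrays X_demo/y_demo as hypothetical stand-ins for X_train/y_train; largest_nodes is a hypothetical helper, not a TensorFlow API. It shows that tf.data.Dataset.from_tensor_slices, called on NumPy arrays, adds a Const node holding the entire array, so data-pipeline calls add to the graph too, not only tf.add-style math.

```python
import numpy as np
import tensorflow as tf


def largest_nodes(graph, k=5):
    """Return the k largest nodes of `graph` as (name, op_type, serialized_bytes)."""
    sizes = [(n.name, n.op, n.ByteSize()) for n in graph.as_graph_def().node]
    return sorted(sizes, key=lambda t: t[2], reverse=True)[:k]


# Hypothetical stand-in data; the real X_train/y_train are not shown in the question.
X_demo = np.random.rand(500, 16, 16, 3).astype(np.float32)
y_demo = np.arange(500)

g = tf.Graph()
with g.as_default():
    # Passing NumPy arrays directly: they are converted to Const nodes inside g.
    ds = tf.data.Dataset.from_tensor_slices((X_demo, y_demo)).batch(32)

for name, op_type, nbytes in largest_nodes(g, k=3):
    print(name, op_type, nbytes)
```

In the question's code, the same kind of Const nodes for X_train and X_test end up in the default graph next to the restored ResNet, which is most likely what pushes the serialized graph past 2GB.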

Any help or pointers to examples would be appreciated.

Thanks.


Solution

  • I went through similar questions on this site and applied a couple of changes that got rid of the error:

    • Using placeholders to feed the data, so that the NumPy arrays are not embedded in the graph as constants:

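      features_placeholder = tf.placeholder(X_train.dtype, X_train.shape)
      labels_placeholder = tf.placeholder(y_train.dtype, y_train.shape)
      dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
      # A one-shot iterator cannot capture placeholders, so use an initializable iterator
      iterator = dataset.make_initializable_iterator()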
      
    • Using a with block to create the session, and loading the graph inside a helper function:

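      with tf.Session() as sess:
          # initialize_iterator and load_graph are my own helpers (not shown);
          # initialize_iterator presumably runs iterator.initializer with the arrays in feed_dict
          initialize_iterator(sess, iterator, X_train, y_train)
          next_element = iterator.get_next()

          load_graph(sess)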
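The two changes above can be combined into one self-contained sketch. It is written against tf.compat.v1 (an assumption, so that it also runs on TF 2.x with eager execution disabled), uses small made-up arrays in place of X_train/y_train, and leaves out restoring the ResNet; iterate_batches is a hypothetical name. It builds get_next() once, supplies the arrays only when the iterator is initialized, and loops until OutOfRangeError.

```python
import numpy as np
import tensorflow as tf

# Needed on TF 2.x to use placeholders and sessions; harmless on TF 1.15.
tf.compat.v1.disable_eager_execution()


def iterate_batches(X, y, batch_size=32):
    """Feed X, y through placeholders and walk every batch once.

    Returns (number_of_batches, number_of_rows_seen, serialized_graphdef_bytes).
    """
    graph = tf.Graph()
    with graph.as_default():
        # Placeholders record only dtype and shape; the array data stays out of the GraphDef.
        features_ph = tf.compat.v1.placeholder(X.dtype, X.shape)
        labels_ph = tf.compat.v1.placeholder(y.dtype, y.shape)
        dataset = tf.data.Dataset.from_tensor_slices((features_ph, labels_ph)).batch(batch_size)
        # A one-shot iterator cannot capture placeholders, so use an initializable one.
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        # Build get_next() ONCE, outside the loop, so the loop adds no new ops.
        next_images, next_labels = iterator.get_next()

    with tf.compat.v1.Session(graph=graph) as sess:
        # The actual data is supplied only here, at run time.
        sess.run(iterator.initializer, feed_dict={features_ph: X, labels_ph: y})
        num_batches, num_rows = 0, 0
        while True:
            try:
                images = sess.run(next_images)
                # ... run inference on `images` here ...
                num_batches += 1
                num_rows += images.shape[0]
            except tf.errors.OutOfRangeError:
                break
        graph_bytes = len(graph.as_graph_def().SerializeToString())
    return num_batches, num_rows, graph_bytes
```

Because the arrays are fed at run time, the serialized graph only describes the placeholders and dataset ops, so its size does not grow with X_train; the final batch is simply smaller when the row count is not a multiple of batch_size.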