I am trying to feed minibatches of numpy arrays to my model, but I'm stuck with batching. Using 'tf.train.shuffle_batch' raises an error because the 'images' array is larger than 2 GB. I tried to get around it by creating placeholders, but when I try to feed the arrays they are still represented by tf.Tensor objects. My main concern is that I defined the operations in the model class, and they don't get called before the session is run. Does anyone have an idea how to handle this issue?
    import tensorflow as tf

    def main(mode, steps):
        config = Configuration(mode, steps)
        if config.TRAIN_MODE:
            images, labels = read_data(config.simID)
            assert images.shape[0] == labels.shape[0]
            images_placeholder = tf.placeholder(images.dtype,
                                                images.shape)
            labels_placeholder = tf.placeholder(labels.dtype,
                                                labels.shape)
            dataset = tf.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            # shuffle
            dataset = dataset.shuffle(buffer_size=1000)
            # batch
            dataset = dataset.batch(batch_size=config.batch_size)
            iterator = dataset.make_initializable_iterator()
            image, label = iterator.get_next()
            model = Model(config, image, label)
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(iterator.initializer,
                         feed_dict={images_placeholder: images,
                                    labels_placeholder: labels})
                # ...
                for step in range(steps):
                    sess.run(model.optimize)
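You are using the initializable iterator of `tf.data` to feed data to your model. This means that you can parametrize the dataset in terms of placeholders, and then call an initializer op for the iterator to prepare it for use. Because the arrays are fed into the placeholders when the iterator is initialized, instead of being embedded in the graph as constants, this also avoids the 2 GB graph size limit.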
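A minimal, self-contained sketch of this pattern, assuming the `tf.compat.v1` graph-mode API (so it also runs on TensorFlow 2.x) and small toy arrays in place of the real data; the array shapes, batch size, and names such as `images_ph` are illustrative only. It uses the function form `tf.data.make_initializable_iterator(dataset)`, the non-deprecated equivalent of the `dataset.make_initializable_iterator()` method:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, as in the TF1 code above

# Toy stand-ins for the real (>2 GB) arrays; shapes are assumptions.
images = np.random.rand(10, 4, 4).astype(np.float32)
labels = np.arange(10, dtype=np.int64)

# Placeholders keep the array data out of the serialized graph.
images_ph = tf.placeholder(images.dtype, images.shape)
labels_ph = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((images_ph, labels_ph))
dataset = dataset.batch(4)
iterator = tf.data.make_initializable_iterator(dataset)
image_batch, label_batch = iterator.get_next()

batch_sizes = []
with tf.Session() as sess:
    # The arrays are fed exactly once, when the iterator is initialized.
    sess.run(iterator.initializer,
             feed_dict={images_ph: images, labels_ph: labels})
    while True:
        try:
            imgs, lbls = sess.run([image_batch, label_batch])
        except tf.errors.OutOfRangeError:
            break  # one full pass over the data is done
        batch_sizes.append(imgs.shape[0])

print(batch_sizes)  # [4, 4, 2]
```

The loop stops after exactly one pass because the iterator raises `tf.errors.OutOfRangeError` when the data is exhausted; the last batch is smaller because 10 is not a multiple of 4.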
If you use the initializable iterator, or any other iterator from `tf.data`, to feed inputs to your model, you should not use the `feed_dict` argument of `sess.run` to feed the training batches. Instead, define your model in terms of the outputs of `iterator.get_next()` and omit the `feed_dict` from `sess.run`.
Something along these lines:
    iterator = dataset.make_initializable_iterator()
    image_batch, label_batch = iterator.get_next()

    # use get_next outputs to define model
    model = Model(config, image_batch, label_batch)

    # placeholders fed in while initializing the iterator
    sess.run(iterator.initializer,
             feed_dict={images_placeholder: images,
                        labels_placeholder: labels})

    for step in range(steps):
        # iterator will feed image and label in the background
        sess.run(model.optimize)
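The iterator feeds data to your model in the background, so additional feeding via `feed_dict` is not necessary. Note that `image_batch` and `label_batch` are still `tf.Tensor` objects on the Python side; they only take on concrete values when `sess.run` evaluates an op that depends on them, such as `model.optimize`.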
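An end-to-end sketch of such a training loop, under the same `tf.compat.v1` assumption. A hypothetical one-weight linear model stands in for `Model`, and toy data (`y = 3x`) stands in for `read_data`; neither is the original code. Unlike the snippets above, it adds `dataset.repeat()`, so that running more steps than one epoch's worth of batches does not exhaust the iterator:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical toy data standing in for read_data(): y = 3x.
features = np.linspace(0.0, 1.0, 100, dtype=np.float32).reshape(-1, 1)
targets = 3.0 * features

features_ph = tf.placeholder(features.dtype, features.shape)
targets_ph = tf.placeholder(targets.dtype, targets.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_ph, targets_ph))
# repeat() is an addition: without it, the iterator raises
# OutOfRangeError once a single epoch has been consumed.
dataset = dataset.shuffle(buffer_size=100).repeat().batch(10)
iterator = tf.data.make_initializable_iterator(dataset)
x_batch, y_batch = iterator.get_next()

# Hypothetical one-weight "model" built directly on the get_next() tensors.
w = tf.Variable(0.0)
loss = tf.reduce_mean(tf.square(w * x_batch - y_batch))
optimize = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer,
             feed_dict={features_ph: features, targets_ph: targets})
    for step in range(500):
        sess.run(optimize)  # no feed_dict: batches come from the iterator
    learned_w = sess.run(w)

print(learned_w)  # expected to be close to 3.0
```

The only `feed_dict` in the whole program is the one passed to `iterator.initializer`; every training step pulls its batch from the iterator.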