How should one initialize a tf.contrib.data.Iterator when tf.estimator.Estimator is also used?
One of the problems is that the input graph (the part of the TF graph that handles input) is supposed to be defined in input_fn(), because tf.estimator creates a separate graph. This requirement makes it hard to access the iterator init ops and pass them to tf.estimator (ops can be passed to train/evaluate/predict in the form of hooks).
One option is to wrap your input_fn inside another function that sets up a simple SessionRunHook, init_hook. All the ops are defined inside input_fn, which gets called in the same graph as the rest of your model, but from inside it you can set the iterator_init_op as an attribute on init_hook.
def get_input_fn(mode="train"):
    init_hook = IteratorInitHook()

    def input_fn():
        ...
        iterator = dataset.make_initializable_iterator()
        init_hook.iterator_init_op = iterator.initializer
        return iterator.get_next()

    return input_fn, init_hook
class IteratorInitHook(tf.train.SessionRunHook):
    def after_create_session(self, session, coord):
        session.run(self.iterator_init_op)
Now, when constructing an Experiment, you can get these input functions and init hooks, which get called when the train/eval sessions are created. It should work equivalently with estimator.train.
train_input_fn, train_init_hook = get_input_fn("train")
test_input_fn, test_init_hook = get_input_fn("test")
return tf.contrib.learn.Experiment(
    estimator=estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=test_input_fn,
    train_monitors=[train_init_hook],
    eval_hooks=[test_init_hook],
)
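TensorFlow aside, the closure mechanism this answer relies on can be sketched in plain Python. All names below are stand-ins for illustration, not TF APIs: the hook object is created before input_fn exists, the inner input_fn assigns the "init op" onto it when the framework later calls input_fn, and the hook fires after session creation.

```python
class IteratorInitHook:
    """Stand-in for tf.train.SessionRunHook; the init op is just a callable."""
    iterator_init_op = None

    def after_create_session(self):
        # The framework would call this once the session exists.
        self.iterator_init_op()


def get_input_fn(mode="train"):
    # Hook is created up front, before input_fn has ever been called.
    init_hook = IteratorInitHook()

    def input_fn():
        # Stand-in for the dataset/iterator state built inside input_fn.
        state = {"initialized": False}

        def initializer():  # stand-in for iterator.initializer
            state["initialized"] = True

        # The closure reaches out and sets the attribute on the outer hook.
        init_hook.iterator_init_op = initializer
        return state  # stand-in for (features, labels)

    return input_fn, init_hook


input_fn, hook = get_input_fn("train")
state = input_fn()           # framework calls input_fn while building the graph
hook.after_create_session()  # framework runs the hook once the session is up
print(state["initialized"])  # -> True
```

The key point is ordering: init_hook can be handed to the training call before input_fn runs, because the attribute is only read inside after_create_session, by which time input_fn has already assigned it.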