tensorflow, tensorflow-datasets, tensorflow-estimator

Initialization of tf.contrib.data.Iterator with tf.estimator


How should one initialize a tf.contrib.data.Iterator when tf.estimator.Estimator is also used?

One of the problems is that the input graph (the part of the TF graph that handles input) is supposed to be defined inside input_fn(), because tf.estimator creates a separate graph.

This requirement makes it hard to access the iterator's initializer op and pass it to tf.estimator. (Ops can be passed in when calling train/evaluate/predict, in the form of hooks.)
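To make the problem concrete, here is a framework-free Python simulation; no TensorFlow is involved, and `Graph`, `Op`, and `ToyEstimator` are hypothetical stand-ins. It models only the one property that matters here: each `train()` call builds a fresh graph and calls `input_fn` inside it, so an initializer op created beforehand belongs to a different graph than the one the estimator's session runs.

```python
# Framework-free simulation: Graph, Op and ToyEstimator are hypothetical
# stand-ins for TensorFlow concepts, not real TensorFlow classes.
class Graph:
    """A named stand-in for tf.Graph."""

    def __init__(self, name):
        self.name = name


class Op:
    """An op remembers the graph it was created in."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph


class ToyEstimator:
    """Mimics one behaviour of tf.estimator: each train() call uses a new graph."""

    def train(self, input_fn, init_op):
        graph = Graph("estimator-graph")  # a brand-new graph per call
        input_fn(graph)  # the input pipeline is built inside that graph
        if init_op.graph is not graph:
            raise ValueError(f"{init_op.name} was built in {init_op.graph.name}, "
                             f"not in {graph.name}")
        return f"ran {init_op.name}"


outer_graph = Graph("outer-graph")
# Creating the iterator initializer outside input_fn puts it in the wrong graph.
early_init_op = Op("iterator_init", outer_graph)

try:
    ToyEstimator().train(input_fn=lambda g: Op("get_next", g), init_op=early_init_op)
except ValueError as err:
    print(err)  # prints: iterator_init was built in outer-graph, not in estimator-graph
```

The only op that can be run is one created while `input_fn` executes, which is why the initializer has to be captured from inside `input_fn`.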


Solution

  • One option is to wrap your input_fn inside another function that also creates a simple SessionRunHook, init_hook. All the ops are still defined inside input_fn, which is called in the same graph as the rest of your model, and from inside it you can store the iterator's initializer as the iterator_init_op attribute of init_hook.

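    import tensorflow as tf

    class IteratorInitHook(tf.train.SessionRunHook):
        """Runs the iterator's initializer as soon as the session exists."""

        def __init__(self):
            self.iterator_init_op = None  # set later, from inside input_fn

        def after_create_session(self, session, coord):
            session.run(self.iterator_init_op)

    def get_input_fn(mode="train"):
        init_hook = IteratorInitHook()

        def input_fn():
            ...  # build `dataset` for the given mode
            iterator = dataset.make_initializable_iterator()
            init_hook.iterator_init_op = iterator.initializer
            return iterator.get_next()  # (features, labels)

        return input_fn, init_hook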
    

    Now, when constructing an Experiment, you can get these input functions and init hooks; the hooks are run when the train and eval sessions are created. The same approach should work with estimator.train.

    def experiment_fn(estimator):  # hypothetical wrapper name
        train_input_fn, train_init_hook = get_input_fn("train")
        test_input_fn, test_init_hook = get_input_fn("test")

        return tf.contrib.learn.Experiment(
            estimator=estimator,
            train_input_fn=train_input_fn,
            eval_input_fn=test_input_fn,
            train_monitors=[train_init_hook],  # initializes the train iterator
            eval_hooks=[test_init_hook],       # initializes the eval iterator
        )
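The pattern works because of call order: the estimator calls `input_fn` while building its graph, which fills in `iterator_init_op`, and only afterwards creates the session and calls each hook's `after_create_session`. The framework-free simulation below reproduces that order in plain Python; `FakeSession`, `FakeIterator`, and `run_training` are hypothetical stand-ins, not TensorFlow APIs, and the hook keeps the same shape as the answer's minus the `tf.train.SessionRunHook` base class.

```python
# Framework-free simulation of the deferred-attribute hook pattern.
# FakeSession, FakeIterator and run_training are hypothetical stand-ins.

class FakeSession:
    def __init__(self):
        self.log = []

    def run(self, op):
        self.log.append(op)  # record which "op" was executed
        return op


class FakeIterator:
    def __init__(self, mode):
        self.initializer = f"{mode}_iterator_init"

    def get_next(self):
        return "features", "labels"


class IteratorInitHook:
    def __init__(self):
        self.iterator_init_op = None  # filled in later by input_fn

    def after_create_session(self, session, coord):
        session.run(self.iterator_init_op)


def get_input_fn(mode="train"):
    init_hook = IteratorInitHook()

    def input_fn():
        iterator = FakeIterator(mode)  # stands in for make_initializable_iterator()
        init_hook.iterator_init_op = iterator.initializer
        return iterator.get_next()

    return input_fn, init_hook


def run_training(input_fn, hooks):
    """Mimics the estimator's order: build inputs, create session, call hooks."""
    features, labels = input_fn()  # 1. input graph built, hook attribute set
    session = FakeSession()  # 2. session created
    for hook in hooks:
        hook.after_create_session(session, coord=None)  # 3. initializer runs
    return session.log


train_input_fn, train_init_hook = get_input_fn("train")
print(run_training(train_input_fn, hooks=[train_init_hook]))  # ['train_iterator_init']
```

With a plain Estimator instead of an Experiment, the real hooks would be passed the same way, e.g. `estimator.train(input_fn=train_input_fn, hooks=[train_init_hook])` and `estimator.evaluate(input_fn=test_input_fn, hooks=[test_init_hook])`.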