python, tensorflow, tensorflow-estimator

TensorFlow Estimator: predict without reloading the checkpoint every time


I am using the Estimator API in TensorFlow 1.8 with Python 3.6 to build a neural network for a reinforcement learning project. I noticed that every call to estimator.predict() reloads the checkpoint from model_dir. This is extremely inefficient when the function has to be called many times against the same checkpoint. In reinforcement learning, for example, I predict the next action from the current state, and the next state only becomes available after a specific action has been chosen, so it is common to call this function thousands of times.

So my question is: how can I call this function repeatedly without reloading the same checkpoint every time?

Thank you.


Solution

  • I think I have found a good answer to my own question: build a tf.data.Dataset from a generator (tf.data.Dataset.from_generator). Link is here.

    The generator keeps the estimator.predict() call open, so the checkpoint does not have to be loaded again for each prediction. The only thing you need to do is update the object the generator yields (self.next_feature in the FastPredict object) before requesting the next prediction.

    However, if your ultimate goal is to turn the whole thing into a service or something similar, you will probably need something like TensorFlow Serving, and I suggest going that way directly. I wasted a lot of time on this, so I hope this answer saves you some.
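The generator approach described above can be sketched in plain Python. This is a minimal illustration of the control flow only, not the linked implementation: FakeEstimator is a hypothetical stand-in for tf.estimator.Estimator whose predict() is a lazy generator (as the real one is), and the FastPredict class here is an assumed reconstruction based on the answer's description.

```python
class FakeEstimator:
    """Hypothetical stand-in for tf.estimator.Estimator (NOT the real API)."""

    def __init__(self):
        self.checkpoint_loads = 0

    def predict(self, input_fn):
        # Like the real Estimator.predict(), this is a generator:
        # nothing runs until the first next() call.
        self.checkpoint_loads += 1  # simulates the costly checkpoint restore
        for features in input_fn():
            yield {"action": sum(features)}  # dummy "model"


class FastPredict:
    """Assumed sketch: keeps one predict() generator open across calls."""

    def __init__(self, estimator):
        self.estimator = estimator
        self.next_feature = None
        self._predictions = None

    def _feature_stream(self):
        # Endless generator: each step yields the most recently set feature.
        while True:
            yield self.next_feature

    def predict(self, feature):
        self.next_feature = feature
        if self._predictions is None:
            # First call only: open predict() once and keep it alive.
            self._predictions = self.estimator.predict(
                input_fn=self._feature_stream
            )
        return next(self._predictions)


# Naive usage: a fresh predict() call per state restores every time.
naive = FakeEstimator()
for state in ([1, 2], [10, 20]):
    next(naive.predict(lambda: iter([state])))
print(naive.checkpoint_loads)  # 2

# Generator-backed usage: the restore happens once.
estimator = FakeEstimator()
fast = FastPredict(estimator)
print(fast.predict([1, 2]))        # {'action': 3}
print(fast.predict([10, 20]))      # {'action': 30}
print(estimator.checkpoint_loads)  # 1
```

With a real Estimator, the input_fn would presumably wrap the generator with tf.data.Dataset.from_generator (which requires output_types to be specified) and return the resulting dataset. The key point is the same: the generator never terminates, so predict() stays open and each next() consumes exactly one newly set feature.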