
Batch prediction using a trained Object Detection API model and TF 2


I successfully trained a model on TPUs using the Object Detection API for TF 2 and exported it in SavedModel format (.pb). I then load it back with tf.saved_model.load, and it works fine when predicting boxes for a single image converted to a tensor with shape (1, h, w, 3).

import tensorflow as tf
import numpy as np

# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')

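image_path = '/path/to/image.jpg'  # placeholder: point this at a real JPEG
image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image, channels=3).numpy()  # uint8, shape (h, w, 3)
input_tensor = np.expand_dims(image_np, 0)  # add batch dimension: shape (1, h, w, 3)
detections = detect_fn(input_tensor)  # This works fine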

The problem is that I need to run batch prediction to scale this to half a million images, but the model's input signature seems to accept only data with shape (1, h, w, 3). This also means that I can't use batching with TensorFlow Serving. How can I solve this? Can I simply change the model signature to accept batches of data?
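The restriction is visible in the exported signature itself. The sketch below does not use the real detector: it builds a hypothetical stand-in module (`FixedBatchDetector`, with a dummy `mean_pixel` output) whose `input_signature` uses a fixed batch size of 1, i.e. `[1, None, None, 3]` uint8, saves and reloads it, prints the signature, and shows that a batch of two images is rejected. For the real model, the equivalent check would presumably be `detect_fn.signatures['serving_default'].structured_input_signature`, assuming the export used the default signature key.

```python
import tempfile

import tensorflow as tf


class FixedBatchDetector(tf.Module):
    """Hypothetical stand-in for an exported detector with the batch size fixed to 1."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
    def __call__(self, input_tensor):
        # Dummy "detection" output: the mean pixel value of each image in the batch.
        return {"mean_pixel": tf.reduce_mean(tf.cast(input_tensor, tf.float32), axis=[1, 2, 3])}


export_dir = tempfile.mkdtemp()
module = FixedBatchDetector()
tf.saved_model.save(module, export_dir, signatures=module.__call__.get_concrete_function())

loaded = tf.saved_model.load(export_dir)
signature = loaded.signatures["serving_default"]
# structured_input_signature is an (args, kwargs) pair; the TensorSpec sits in kwargs.
input_spec = signature.structured_input_signature[1]["input_tensor"]
print(input_spec)  # batch dimension is 1, height and width are None

single = tf.zeros([1, 64, 48, 3], dtype=tf.uint8)
batch = tf.zeros([2, 64, 48, 3], dtype=tf.uint8)
single_shape = loaded(single)["mean_pixel"].shape.as_list()

try:
    loaded(batch)  # batch size 2 is incompatible with the fixed batch size 1
    batch_rejected = False
except (ValueError, TypeError, tf.errors.InvalidArgumentError):
    batch_rejected = True
print("batch of 2 rejected:", batch_rejected)
```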

All work (loading the model and running predictions) was performed inside the official container released with the Object Detection API (from here).


Solution

  • I ran into this issue recently. When you use exporter_main_v2.py to convert checkpoint files into a .pb file (SavedModel), it calls exporter_lib_v2.py. In exporter_lib_v2.py (here), the input signature is hard-coded with shape [1, None, None, 3]; it has to be changed to [None, None, None, 3].

    Modify lines 138, 162, 170 and 185 of that file, changing the batch dimension from 1 to None. Then rebuild the TF2 Object Detection API repo (link) and use the newly built version to export the .pb again.
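Once the batch dimension is None, batches of any size are accepted, but every image in one batch must still share the same height and width, because a batch is a single rectangular tensor. The sketch below illustrates that under stated assumptions: `BatchedDetector` is a hypothetical stand-in for the re-exported model (not the real exporter output), the JPEGs are synthetic, and `IMAGE_SIZE` and `BATCH_SIZE` are arbitrary values. It feeds the model through a `tf.data` pipeline that decodes, resizes, and batches the images; with a real re-exported model you would load it from your own export directory instead.

```python
import os
import tempfile

import tensorflow as tf

IMAGE_SIZE = (320, 320)  # hypothetical common size; all images in a batch must share H and W
BATCH_SIZE = 4           # hypothetical batch size


class BatchedDetector(tf.Module):
    """Hypothetical stand-in for a re-exported detector whose batch dimension is None."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8)])
    def __call__(self, input_tensor):
        # Dummy "detection" output: one value per image in the batch.
        return {"mean_pixel": tf.reduce_mean(tf.cast(input_tensor, tf.float32), axis=[1, 2, 3])}


def write_fake_jpegs(directory, count):
    """Write `count` random JPEGs of varying sizes so the sketch runs without real data."""
    paths = []
    for i in range(count):
        height, width = 200 + 10 * i, 300 - 10 * i
        pixels = tf.random.uniform([height, width, 3], maxval=256, dtype=tf.int32)
        path = os.path.join(directory, f"image_{i}.jpg")
        tf.io.write_file(path, tf.io.encode_jpeg(tf.cast(pixels, tf.uint8)))
        paths.append(path)
    return paths


def load_image(path):
    # Decode to uint8 [H, W, 3], then resize so images of different sizes can be stacked.
    image = tf.image.decode_jpeg(tf.io.read_file(path), channels=3)
    return tf.cast(tf.image.resize(image, IMAGE_SIZE), tf.uint8)


work_dir = tempfile.mkdtemp()
export_dir = os.path.join(work_dir, "saved_model")
module = BatchedDetector()
tf.saved_model.save(module, export_dir, signatures=module.__call__.get_concrete_function())
detect_fn = tf.saved_model.load(export_dir)

image_paths = write_fake_jpegs(work_dir, count=10)
dataset = (
    tf.data.Dataset.from_tensor_slices(image_paths)
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)  # the last batch may be smaller; the None batch dimension allows that
    .prefetch(tf.data.AUTOTUNE)
)

results = []
for batch in dataset:  # each batch has shape [<= BATCH_SIZE, 320, 320, 3]
    detections = detect_fn(batch)
    results.extend(detections["mean_pixel"].numpy().tolist())

print(len(results))  # one output per input image
```

Resizing is only one way to get a uniform shape; padding to a common size would also work, and which one suits your model depends on how its pipeline config preprocesses images.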