Tags: tensorflow, object-detection, tensorflow2.0, object-detection-api

How do I load the two stages of a saved Faster R-CNN separately in TF Object Detection 2.0?


I trained a Faster R-CNN from the TF Object Detection API and saved it using export_inference_graph.py. I have the following directory structure:

weights
|-checkpoint
|-frozen_inference_graph.pb
|-model.ckpt.data-00000-of-00001
|-model.ckpt.index
|-model.ckpt.meta
|-pipeline.config
|-saved_model
|--saved_model.pb
|--variables

I would like to load the first and second stages of the model separately. That is, I would like the following two models:

  1. A model containing each variable in the scope FirstStageFeatureExtractor, which accepts an image (or serialized tf.train.Example) as input and outputs the feature map and RPN proposals.

  2. A model containing each variable in the scopes SecondStageFeatureExtractor and SecondStageBoxPredictor, which accepts a feature map and RPN proposals as input and outputs the bounding box predictions and scores.

I basically want to be able to call _predict_first_stage and _predict_second_stage separately on my input data.

Currently, I only know how to load the entire model:

model = tf.saved_model.load("weights/saved_model")
model = model.signatures["serving_default"]
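
For context, I run inference on the whole model roughly like this, continuing from the snippet above (assuming the exported model uses the default image_tensor input and the standard detection output keys; the zero image is just a placeholder):

import numpy as np
import tensorflow as tf

# Placeholder uint8 image batch of shape [1, H, W, 3]; replace with a real image.
image = np.zeros((1, 640, 640, 3), dtype=np.uint8)
outputs = model(tf.convert_to_tensor(image))

# Typical output keys: detection_boxes, detection_scores, detection_classes, num_detections
print(outputs["detection_boxes"].shape)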

EDIT 6/7/2020: For Model 1, I may be able to extract detection_features as in this question, but I'm still not sure about Model 2.


Solution

  • This was more difficult when Object Detection was only compatible with TF1, but is now pretty simple in TF2. There's a good example in this colab.

    import os

    import tensorflow as tf

    from object_detection.builders import model_builder
    from object_detection.utils import config_util

    # Set path names
    model_name = 'centernet_hg104_512x512_kpts_coco17_tpu-32'
    pipeline_config = os.path.join('models/research/object_detection/configs/tf2/',
                                   model_name + '.config')
    model_dir = 'models/research/object_detection/test_data/checkpoint/'

    # Load pipeline config and build a detection model
    configs = config_util.get_configs_from_pipeline_file(pipeline_config)
    model_config = configs['model']
    detection_model = model_builder.build(model_config=model_config,
                                          is_training=False)

    # Restore checkpoint
    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
    ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()
    

    From here one can call detection_model.predict() and associated methods such as _predict_first_stage and _predict_second_stage. For a Faster R-CNN pipeline config (such as the question's pipeline.config), detection_model is a FasterRCNNMetaArch, which exposes both stage methods; a sketch of calling the two stages separately follows.
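
    A minimal sketch of the two-stage call path (the dummy image shape is a placeholder, and the prediction-dict keys and _predict_second_stage argument order are taken from faster_rcnn_meta_arch.py and may vary slightly between versions, so check them against your installed copy):

    import tensorflow as tf

    # Dummy float image batch; replace with real input data.
    image = tf.zeros([1, 640, 640, 3], dtype=tf.float32)

    # Model-specific resizing / normalization.
    preprocessed_image, true_image_shapes = detection_model.preprocess(image)

    # Stage 1: backbone feature map + RPN proposals.
    first_stage = detection_model._predict_first_stage(preprocessed_image)

    # Stage 2: box refinement and classification of the RPN proposals.
    second_stage = detection_model._predict_second_stage(
        first_stage['rpn_box_encodings'],
        first_stage['rpn_objectness_predictions_with_background'],
        first_stage['rpn_features_to_crop'],
        first_stage['anchors'],
        first_stage['image_shape'],
        true_image_shapes)

    # Merging both dictionaries mirrors what detection_model.predict() returns,
    # so the result can be handed to postprocess() as usual.
    detections = detection_model.postprocess({**first_stage, **second_stage},
                                             true_image_shapes)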