tensorflow, machine-learning, deep-learning, object-detection

TensorFlow object_detection: correct way to save and load a fine-tuned model


I'm using this example from the Colab tutorial to fine-tune a model. After training, I want to save the model and load it on my local computer, using:

ckpt_manager = tf.train.CheckpointManager(ckpt, directory="test_data/checkpoint/", max_to_keep=5)
...
...
print('Done fine-tuning!')

ckpt_manager.save()
print('Checkpoint saved!')

But after restoring from the checkpoint files on my local computer, the model doesn't detect any objects (the scores are too low).
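One plausible cause, assuming `ckpt` here is the checkpoint object from the Colab tutorial: there it wraps a `fake_model` that tracks only the feature extractor and the box-regression parts of the box predictor (the `_prediction_heads` line is deliberately commented out), so the freshly trained classification head is never written to disk. The helpers below are a hypothetical sketch that lists what a checkpoint actually contains, so you can check whether the classification-head variables were saved. The substring `"_prediction_heads"` is an assumption about how the tutorial's box predictor names that head.

```python
import tensorflow as tf


def checkpoint_variable_names(ckpt_dir):
    """Return the variable keys stored in the latest checkpoint in ckpt_dir."""
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is None:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_dir}")
    # list_variables yields (name, shape) pairs for every tensor in the checkpoint.
    return [name for name, _shape in tf.train.list_variables(latest)]


def has_variables_matching(ckpt_dir, substring):
    """True if any saved variable key contains `substring`."""
    return any(substring in name for name in checkpoint_variable_names(ckpt_dir))
```

For example, `has_variables_matching("test_data/checkpoint/", "_prediction_heads")` returning `False` would suggest the classification head was not saved; checkpointing the whole model with `tf.train.Checkpoint(model=detection_model)` avoids that.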

I have also tried saving with:

tf.saved_model.save(detection_model, '/content/new_model/')

and loading with:

detection_model = tf.saved_model.load('/saved_model_20201226/')

input_tensor = tf.convert_to_tensor(image, dtype=tf.float32)
detections = detection_model(input_tensor)

This gives me the following error: TypeError: '_UserObject' object is not callable
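The `_UserObject` error typically means the SavedModel has no traced, callable function attached to the restored object. The Object Detection API ships `exporter_main_v2.py` for exporting a pipeline config plus checkpoint as a SavedModel; as a hand-rolled alternative, the sketch below wraps the model in a `tf.Module` whose `__call__` is a `tf.function` running the same preprocess, predict, postprocess chain as the tutorial's `detect()`. `DetectionModule` is a hypothetical name, and the fixed batch size of 1 with float32 input is an assumption chosen to match the question's `input_tensor`.

```python
import tensorflow as tf


class DetectionModule(tf.Module):
    """Hypothetical wrapper so tf.saved_model.save records a callable __call__."""

    def __init__(self, detection_model):
        super().__init__()
        # Assigning the model as an attribute lets tf.Module track its variables.
        self.detection_model = detection_model

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.float32)])
    def __call__(self, input_tensor):
        # Same preprocess -> predict -> postprocess chain as the tutorial's detect().
        preprocessed, shapes = self.detection_model.preprocess(input_tensor)
        prediction_dict = self.detection_model.predict(preprocessed, shapes)
        return self.detection_model.postprocess(prediction_dict, shapes)
```

Usage: `tf.saved_model.save(DetectionModule(detection_model), '/content/new_model/')`, then after `loaded = tf.saved_model.load(...)` the call `loaded(input_tensor)` works, provided `input_tensor` has shape `[1, height, width, 3]`.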

What is the correct way to save and load a fine tuned model?

EDIT 1: I had forgotten to save the new pipeline config; after doing that, it finally worked! This is my answer:

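import tensorflow as tf
from object_detection.utils import config_util

# Save new pipeline config (configs: the dict loaded and modified earlier for fine-tuning)
new_pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(new_pipeline_proto, '/content/new_config')

# Checkpoint the whole detection model, not just the parts restored from the pretrained weights
exported_ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt_manager = tf.train.CheckpointManager(
    exported_ckpt, directory="test_data/checkpoint/", max_to_keep=5)
...
...
print('Done fine-tuning!')

ckpt_manager.save()
print('Checkpoint saved!')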
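On the local computer, a minimal loading sketch (function names are hypothetical, and it assumes the `new_config` directory and the checkpoint directory were copied over): rebuild the model from the saved pipeline config so that its architecture, for example `num_classes`, matches the model that was trained, then restore a `Checkpoint(model=...)` with the same structure it was saved with. `config_util.save_pipeline_config` writes the file as `pipeline.config` inside the given directory.

```python
import os

import tensorflow as tf


def load_fine_tuned_model(config_dir, ckpt_dir):
    """Rebuild the detection model from the saved config and restore its weights."""
    # Imported here so the sketch only needs the Object Detection API when called.
    from object_detection.builders import model_builder
    from object_detection.utils import config_util

    configs = config_util.get_configs_from_pipeline_file(
        os.path.join(config_dir, 'pipeline.config'))
    detection_model = model_builder.build(
        model_config=configs['model'], is_training=False)

    # Same object structure as the saved checkpoint: Checkpoint(model=...).
    ckpt = tf.train.Checkpoint(model=detection_model)
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is None:
        raise FileNotFoundError(f'No checkpoint found in {ckpt_dir}')
    # expect_partial() silences warnings about saved values (e.g. optimizer) left unused.
    ckpt.restore(latest).expect_partial()
    return detection_model


def detect(detection_model, input_tensor):
    """Run preprocess -> predict -> postprocess, as in the Colab tutorial."""
    preprocessed, shapes = detection_model.preprocess(input_tensor)
    prediction_dict = detection_model.predict(preprocessed, shapes)
    return detection_model.postprocess(prediction_dict, shapes)
```

For example, `detection_model = load_fine_tuned_model('new_config', 'test_data/checkpoint/')` followed by `detect(detection_model, input_tensor)`, where `input_tensor` is float32 with shape `[1, height, width, 3]`; if `image` is a single HxWx3 array, add the batch dimension with `tf.expand_dims(..., 0)` first.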
