Tags: python, tensorflow, machine-learning, tensorflow-lite

How to continue training with checkpoints using object_detector.EfficientDetLite4Spec (TensorFlow Lite)


Previously, I set grad_checkpoint=true for my EfficientDetLite4 model in config.yaml, and it successfully generated some checkpoints. However, I can't figure out how to use these checkpoints to continue training from where they left off.

Every time I train the model, it just starts from the beginning instead of from my checkpoints.

The following picture shows my Colab file system structure:

[image: my Colab file system structure]

The following picture shows where my checkpoints are stored:

[image: checkpoint files in the model directory]

The following code shows how I configure and train the model:

import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

train_data, validation_data, test_data = \
    object_detector.DataLoader.from_csv('csv_path')

spec_strategy = 'gpus'  # the `strategy` argument expects 'gpus', 'tpu', or None

spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50, batch_size=3,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25, strategy=spec_strategy
)

model = object_detector.create(train_data, model_spec=spec, batch_size=3, 
    train_whole_model=True, validation_data=validation_data)

Solution

  • The source code is the answer!

    I ran into the same problem and found out that the model_dir we pass to the TFLite Model Maker's object detector API is only used for saving the model's weights: that's why the API never restores from checkpoints.

    Having a look at the source code of this API, I noticed that it internally uses the standard model.compile and model.fit functions and that it saves the model's weights through the callbacks parameter of model.fit.
    This means that, as long as we can get hold of the internal Keras model, we can restore our checkpoints by calling model.load_weights!
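
    In short, the core idea is the following (a minimal sketch, assuming detector and checkpoint_dir are defined as in the full listing below; the complete, working code follows):

    #Minimal sketch of the restore idea (assumes `detector` and
    #`checkpoint_dir` are defined as in the full listing below)
    model = detector.create_model()                      #get the internal Keras model
    latest = tf.train.latest_checkpoint(checkpoint_dir)  #e.g. ".../ckpt-35"
    model.load_weights(latest)                           #restore the weights manually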

    If you want to know more about what some of the functions used below do, have a look at the source code of the train and train_lib modules of the EfficientDet Keras implementation that Model Maker uses internally (they are imported in the code below).

    This is the code:

    #Useful imports
    import tensorflow as tf
    from tflite_model_maker.config import QuantizationConfig
    from tflite_model_maker.config import ExportFormat
    from tflite_model_maker import model_spec
    from tflite_model_maker import object_detector
    from tflite_model_maker.object_detector import DataLoader
    
    #Import the same libs that TFLiteModelMaker internally uses
    from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
    from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib
    
    
    
    #Setup variables
    batch_size = 6 #or whatever batch size you want
    epochs = 50
    checkpoint_dir = "/content/..." #whatever your checkpoint directory is
    
    
    
    #Create whichever object detector's spec you want
    spec = object_detector.EfficientDetLite4Spec(
        model_name='efficientdet-lite4',
        uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2', 
        hparams='', #enable grad_checkpoint=True if you want
        model_dir=checkpoint_dir, 
        epochs=epochs, 
        batch_size=batch_size,
        steps_per_execution=1, 
        moving_average_decay=0,
        var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
        tflite_max_detections=25, 
        strategy=None, 
        tpu=None, 
        gcp_project=None,
        tpu_zone=None, 
        use_xla=False, 
        profile=False, 
        debug=False, 
        tf_random_seed=111111,
        verbose=1
    )
    
    
    
    #Load your datasets
    train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')
    
    
    
    
    #Create the object detector 
    detector = object_detector.create(
        train_data, 
        model_spec=spec, 
        batch_size=batch_size, 
        train_whole_model=True, 
        validation_data=validation_data,
        epochs = epochs,
        do_train = False
    )
    
    
    
    """
    From here on we use internal/"private" functions of the API,
    you can tell because the methods' names begin with an underscore
    """
    
    #Convert the datasets for training
    train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training = True)
    validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training = False)
    
    
    
    
    #Get the internal Keras model
    model = detector.create_model()
    
    
    
    
    #Copy what the API internally does as setup
    config = spec.config
    config.update(
        dict(
            steps_per_epoch=steps_per_epoch,
            eval_samples=batch_size * validation_steps,
            val_json_file=val_json_file,
            batch_size=batch_size
        )
    )
    train.setup_model(model, config) #This is the model.compile call basically
    model.summary()
    
    
    
    
    """
    Here we restore the weights
    """
    
    #Load the weights from the latest checkpoint.
    #In my case:
    #checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/" 
    #specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
    completed_epochs = 0 #default: if no checkpoint is found, training starts from epoch 0
    try:
        #Option A:
        #load the weights from the last successfully completed epoch
        latest = tf.train.latest_checkpoint(checkpoint_dir)

        #Option B:
        #load the weights from a specific checkpoint.
        #Note that there's no ".index" at the end of specific_checkpoint_dir
        #latest = specific_checkpoint_dir

        completed_epochs = int(latest.split("/")[-1].split("-")[1]) #the epoch the training was at when it was last interrupted
        model.load_weights(latest)

        print("Checkpoint found {}".format(latest))
    except Exception as e:
        print("Checkpoint not found: ", e)
    
    
    
    #Retrieve the needed default callbacks
    all_callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)
    
    
    
    """
    Optional step.
    Add callbacks that get executed at the end of every N 
    epochs: in this case I want to log the training results to tensorboard.
    """
    #tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
    #all_callbacks.append(tensorboard_callback)
    
    
    
    
    """
    Train the model 
    """
    model.fit(
        train_ds,
        epochs=epochs,
        initial_epoch=completed_epochs, 
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_ds,
        validation_steps=validation_steps,
        callbacks=all_callbacks #This is for saving checkpoints at the end of every epoch + running the above added callbacks
    )
    
    
    
    
    """
    Save/export the trained model
    Tip: for integer quantization you simply have to NOT SPECIFY 
    the quantization_config parameter of the detector.export method.
    In this case it would be: 
    detector.export(export_dir = export_dir, tflite_filename='model.tflite')
    """
    export_dir = "/content/..." #save the tflite wherever you want
    quant_config = QuantizationConfig.for_float16() #or whatever quantization you want
    detector.model = model #inject our trained model into the object detector
    detector.export(export_dir = export_dir, tflite_filename='model.tflite', quantization_config = quant_config)
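
    To sanity-check the exported model, you can load it with the standard TFLite interpreter and inspect its input/output tensors (a minimal sketch; the file name matches the tflite_filename passed to detector.export above):

    #Quick sanity check of the exported .tflite file
    import os
    interpreter = tf.lite.Interpreter(model_path=os.path.join(export_dir, 'model.tflite'))
    interpreter.allocate_tensors()
    print(interpreter.get_input_details())  #expected input shape/dtype
    print(interpreter.get_output_details()) #detection output tensors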