
Is there any way to implement an early stopping callback for TensorFlow 2 model_main_tf2.py?


Hello, I'm working on object detection with the TensorFlow 2 Object Detection API's model_main_tf2.py file. Normally we can pass an early stopping callback to model.fit(), but when I train through model_main_tf2.py with a pipeline .config file, I can't do that, because I'm unable to locate any model.fit() call in the main file. Is there any way I can implement early stopping for model_main_tf2.py? Please help me.

Link to the file: https://github.com/tensorflow/models/blob/master/research/object_detection/model_main_tf2.py


Solution

  • I had a look inside the model_main_tf2.py file. Let's take the following piece of code:

    model_lib_v2.train_loop(
        pipeline_config_path=FLAGS.pipeline_config_path,
        model_dir=FLAGS.model_dir,
        train_steps=FLAGS.num_train_steps,
        use_tpu=FLAGS.use_tpu,
        checkpoint_every_n=FLAGS.checkpoint_every_n,
        record_summaries=FLAGS.record_summaries)
    
    

    Instead of executing the training through fit, a custom training loop is used: the code above calls the function that actually runs the training. model_lib_v2 is just another file in the repo that you've linked.
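
    For contrast, this is the situation you are describing: with a plain Keras model you would attach tf.keras.callbacks.EarlyStopping to fit(), which is exactly the hook that is missing here. A minimal sketch (the toy model and dummy data are my own, not Object Detection API code):

    import numpy as np
    import tensorflow as tf

    # Dummy data and model, just to illustrate the callback mechanism.
    x = np.random.rand(100, 4).astype("float32")
    y = np.random.rand(100, 1).astype("float32")
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",          # metric to watch
        patience=5,                  # epochs without improvement before stopping
        restore_best_weights=True)

    # fit() owns the training loop, which is why a callback can interrupt it;
    # model_main_tf2.py never calls fit(), so this hook is unavailable there.
    model.fit(x, y, validation_split=0.2, epochs=100, callbacks=[early_stop])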

    If you have a look at the train_loop function, you'll see that at some point the following code is executed:

    with tf.GradientTape() as tape:
        losses_dict, _ = _compute_losses_and_predictions_dicts(
            detection_model, features, labels,
            training_step=training_step,
            add_regularization_loss=add_regularization_loss)

        losses_dict = normalize_dict(losses_dict, num_replicas)

    trainable_variables = detection_model.trainable_variables

    total_loss = losses_dict['Loss/total_loss']
    gradients = tape.gradient(total_loss, trainable_variables)
    

    GradientTape basically computes the gradients needed to update the weights of the model during the training phase. I won't go into much detail, but if you are interested you can have a look at the tf.GradientTape documentation.
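
    To make that concrete, here is a minimal, self-contained training step in the same style (the toy model, loss, and optimizer are assumptions for the sketch, not what the Object Detection API uses):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.MeanSquaredError()

    x = tf.constant(np.random.rand(8, 4), dtype=tf.float32)
    y = tf.constant(np.random.rand(8, 1), dtype=tf.float32)

    with tf.GradientTape() as tape:
        # Operations inside the tape are recorded so they can be differentiated.
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)

    # The tape replays the recorded operations backwards to get the gradients,
    # which the optimizer then uses to update the weights.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))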

    Now, you are interested in adding an early stopping callback, but you don't have a fit call to attach it to. You can still add early stopping, just in a different way.

    You can follow a strategy like the one below (refer to the TensorFlow tutorial on writing a training loop from scratch for the full code):

    import time

    # train_step/test_step and the *_metric objects are assumed to be defined
    # as in the tutorial (a tf.function training step plus tf.keras.metrics).
    epochs = 100
    patience = 5          # you can play with these values to find the best config
    wait = 0
    best = float("inf")   # lower val_loss is better, so start from infinity

    for epoch in range(epochs):
        start_time = time.time()

        # training (calling the function that holds the GradientTape)
        for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
            loss_value = train_step(x_batch_train, y_batch_train)

        # updating the metrics after the whole training loop on a single epoch
        train_acc = train_acc_metric.result()
        train_loss = train_loss_metric.result()
        train_acc_metric.reset_states()
        train_loss_metric.reset_states()
        print("Training acc over epoch: %.4f" % (train_acc.numpy(),))

        # evaluating the model just trained in a new epoch, on the validation data
        for x_batch_val, y_batch_val in ds_test:
            test_step(x_batch_val, y_batch_val)

        # updating the metrics for validation
        val_acc = val_acc_metric.result()
        val_loss = val_loss_metric.result()
        val_acc_metric.reset_states()
        val_loss_metric.reset_states()
        print("Validation acc: %.4f" % (float(val_acc),))
        print("Time taken: %.2fs" % (time.time() - start_time))

        # The early stopping strategy: stop the training if `val_loss` does not
        # decrease over `patience` consecutive epochs.
        wait += 1
        if val_loss < best:
            best = val_loss
            wait = 0
        if wait >= patience:
            break
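
    Coming back to model_main_tf2.py itself: since its loop lives inside model_lib_v2.train_loop, the least invasive option is to leave the training script untouched and watch the metrics it already writes as TensorBoard event files, stopping the training process when the monitored loss plateaus. Below is a rough sketch of that idea; the tag name 'Loss/total_loss' and the event-file layout are assumptions based on what the Object Detection API logs, and depending on your TensorBoard version the values may be exposed via Tensors() rather than Scalars():

    from tensorboard.backend.event_processing import event_accumulator

    def should_stop(eval_dir, tag="Loss/total_loss", patience=5):
        """Return True if `tag` has not improved for `patience` evaluations."""
        ea = event_accumulator.EventAccumulator(eval_dir)
        ea.Reload()  # re-read the event files written by the eval job
        if tag not in ea.Tags().get("scalars", []):
            return False  # nothing logged yet
        values = [s.value for s in ea.Scalars(tag)]
        if len(values) <= patience:
            return False
        best = min(values)
        # stop if the best value did not occur in the last `patience` evaluations
        return best not in values[-patience:]

    You would call should_stop periodically (for example, from the script that launched model_main_tf2.py alongside an evaluation job started with --checkpoint_dir) and terminate the training process when it returns True. Alternatively, you can edit model_lib_v2.train_loop directly and break out of its loop using the same wait/patience bookkeeping shown above.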