Tags: python, keras, deep-learning, transformer-model

Save trained weights in machine learning code


I'm running a machine learning model on Colab, but when it reaches epoch 80 the Colab RAM crashes and I cannot go beyond 80 epochs. I would like to save the trained weights somewhere so that, after the RAM crashes, I can resume training the model from that epoch. This is my code. How and where can I add code for that purpose?

for comm_round in range(comms_round):

    # Start the round from the current global model weights
    global_weights = global_model.get_weights()

    scaled_local_weight_list = list()

    # Visit the clients in a random order each round
    client_names = list(clients_batched.keys())
    random.shuffle(client_names)

    for client in client_names:
        local_model = Transformer
        local_model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
                            optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                            metrics=['acc'])

        # Initialise the local model with the global weights, then train on this client's data
        local_model.set_weights(global_weights)
        local_model.fit(clients_batched[client], epochs=1, verbose=0, callbacks=[checkpoint_callback])

        # Scale the local weights by the client's data share and collect them for aggregation
        scaling_factor = weight_scalling_factor(clients_batched, client)
        scaled_weights = scale_model_weights(local_model.get_weights(), scaling_factor)
        scaled_local_weight_list.append(scaled_weights)

        K.clear_session()

    # Federated averaging: aggregate the scaled client weights into the global model
    average_weights = sum_scaled_weights(scaled_local_weight_list)
    global_model.set_weights(average_weights)

    # Evaluate the updated global model on the test set
    for (X_test, Y_test) in test_batched:
        global_acc, global_loss = test_model(X_test, Y_test, global_model, comm_round + 1)

This code is the final step, and it is for federated learning.


Solution

  • Potential Solution


    Note: I see that you already pass checkpoint_callback to your fit call, which I presume is a model callback. If so, I am not sure whether this exact setup will work for you, but it is the general approach that comes to mind for saving off weights during training with TensorFlow. Since checkpoints are written continuously across epochs, even if training fails at a late epoch, the weights from the earlier epochs should already have been saved.

    You can use TensorFlow's ModelCheckpoint callback to automatically save your model's weights after each epoch. You can then load these weights to resume training if your runtime crashes. Here's how to set it up:

    1. Define the ModelCheckpoint callback before your training loop:
    checkpoint_filepath = '/path/to/checkpoint/directory'  # with save_weights_only=True this path is used as the weights file / checkpoint prefix
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,   # save only the weights, not the full model
        monitor='val_loss',       # effectively ignored here: save_best_only=False and your fit call has no validation data
        mode='auto',
        save_best_only=False,     # save after every epoch, not only the best one
        verbose=1,
    )
    
    2. Include the ModelCheckpoint callback in your fit function:
    for comm_round in range(comms_round):
        # ... existing code ...
    
        for client in client_names:
            # ... existing code ...
    
            local_model.fit(clients_batched[client], epochs=1, verbose=0, callbacks=[checkpoint_callback])
    
            # ... existing code ...
    

    After each epoch, the model's weights are saved to the specified filepath.

    3. If your runtime crashes and you need to resume training, you can load the saved weights back into your model (in your code that would be global_model, or local_model if that is what you want to restore):
    global_model.load_weights(checkpoint_filepath)
    

    Please note that the path to the directory where you want to save the model weights needs to exist. If you are using Google Colab, it might be a good idea to save your weights in your Google Drive. This way, the saved weights will persist even if the Colab runtime is recycled. To do this, you will need to mount your Google Drive to the Colab notebook.
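    As a rough sketch (the folder name federated_checkpoints and the file name global_weights.h5 below are just placeholders, not part of your code), mounting Drive and pointing the callback at it could look like this:
    import os
    import tensorflow as tf
    from google.colab import drive

    # Mount Google Drive so that checkpoints survive a runtime crash or recycle
    drive.mount('/content/drive')

    # Example location on Drive -- adjust to whatever folder you actually use
    checkpoint_dir = '/content/drive/MyDrive/federated_checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)  # the directory must exist
    checkpoint_filepath = os.path.join(checkpoint_dir, 'global_weights.h5')

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        save_best_only=False,
        verbose=1,
    )
    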

    Also, please note that this will not automatically save the current communication round. You would need to manage this manually. For example, you can save the current communication round number to a file each time you save your model weights. When you resume training, you can read this number from the file and continue the communication rounds from there.
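    A minimal sketch of that idea, reusing checkpoint_dir and checkpoint_filepath from the example above (the file name round.txt is just an example), could look like this:
    import os

    round_file = os.path.join(checkpoint_dir, 'round.txt')  # stores the last completed communication round

    # On (re)start: if a previous run left a checkpoint behind, restore it and skip the finished rounds
    start_round = 0
    if os.path.exists(round_file):
        global_model.load_weights(checkpoint_filepath)
        with open(round_file) as f:
            start_round = int(f.read().strip()) + 1

    for comm_round in range(start_round, comms_round):
        # ... your existing federated training code for one communication round ...

        # After aggregating, persist the global weights and the round number
        global_model.save_weights(checkpoint_filepath)
        with open(round_file, 'w') as f:
            f.write(str(comm_round))
    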