tensorflow2.x

How do I replace keras.utils.multi_gpu_model for training with multiple GPUs?


As of TensorFlow 2.4, tensorflow.keras.utils.multi_gpu_model has been removed. I am looking for a way to replace this simple command for training with multiple GPUs:

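import tensorflow as tf
from tensorflow.keras.models import load_model

gpus = len(tf.config.list_physical_devices("GPU"))  # number of visible GPUs
model = load_model("my_model.h5")
if gpus > 1:
    from tensorflow.keras.utils import multi_gpu_model  # not available as of TF 2.4
    model = multi_gpu_model(model, gpus=gpus)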

Here, model is a loaded model that can then be used to train or make predictions on multiple GPUs.


Solution

  • One way to train with multiple GPUs is to use a distribution strategy. The one I found that works as a near drop-in replacement is tf.distribute.MirroredStrategy:

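    import tensorflow as tf
    from tensorflow.keras.models import load_model

    strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU
    with strategy.scope():
        model = load_model("my_model.h5")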

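    Because the model is loaded inside strategy.scope(), its variables are mirrored on every visible GPU, and later calls such as model.fit() or model.predict() run on all of them. Only model creation (and compile(), if you call it yourself) needs to happen inside the scope.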
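As a rough sketch of how the original "if gpus > 1" logic might map onto strategies, the snippet below picks MirroredStrategy when more than one GPU is visible and falls back to the default (no-op) strategy otherwise. The helpers get_strategy() and build_model() are hypothetical names for illustration; build_model() stands in for load_model("my_model.h5") so the example runs without a saved file, and the data is random.

```python
import numpy as np
import tensorflow as tf


def get_strategy():
    """Hypothetical helper: choose a strategy from the number of visible GPUs."""
    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) > 1:
        # Replicate the model on every visible GPU (like multi_gpu_model did).
        return tf.distribute.MirroredStrategy()
    # Default strategy: single device, no replication.
    return tf.distribute.get_strategy()


def build_model():
    """Hypothetical stand-in for load_model("my_model.h5")."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


strategy = get_strategy()
with strategy.scope():
    # Variables created here are mirrored on each replica.
    model = build_model()

# fit() uses the strategy the model was created under; batch_size is the
# global batch, which is split across the replicas.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(x, y, batch_size=16, epochs=1, verbose=0)
print("replicas in sync:", strategy.num_replicas_in_sync)
```

To restrict training to specific GPUs, as the gpus argument of multi_gpu_model did, MirroredStrategy also accepts a devices list, for example tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"]).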