Tags: python-3.x, tensorflow, keras, deep-learning, generative-adversarial-network

Retrain a saved model in Keras that was trained using train_on_batch()


I am working on GANs and need to save my models at the end of each working day, then resume training them the next day from where they left off. To continue training later, I am saving these three models:

Discriminator Model.h5
Generator Model.h5
Generator-on-Discriminator Model.h5

These models use a perceptual loss and a Wasserstein loss. However, when I call load_model() to reload a saved model and resume training, I get the following error:

Unknown loss function:wasserstein_loss
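The question does not show how `wasserstein_loss` is defined. A common way to write it in Keras GANs is the mean of labels times critic outputs, with the two sample types labelled +1 and -1; the function below is an assumed sketch of that form, not necessarily the author's code. What matters for the error is the name: Keras records a plain-function loss under that name when saving, and `load_model()` must later resolve the same string.

```python
import tensorflow as tf

def wasserstein_loss(y_true, y_pred):
    # Assumed (commonly used) definition: labels are +1 / -1 for the two
    # sample types, so the mean product reflects the critic's score gap.
    return tf.reduce_mean(y_true * y_pred)

# Keras stores a plain-function loss under this name in the saved model;
# this is the string load_model() later has to look up.
print(wasserstein_loss.__name__)  # wasserstein_loss
```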

I have also tried Discriminator.compile(loss=wasserstein_loss), but that still does not solve the issue. Can anyone guide me on this and tell me whether it is possible to resume training a saved model with train_on_batch()?
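Calling `compile()` on its own cannot help here, because with the default `compile=True`, `load_model()` tries to rebuild the saved loss and raises before there is a model to recompile. One way around this is `load_model(path, compile=False)` followed by a manual `compile()`. The sketch below assumes `tf.keras`; the tiny model, the random batch and the temporary file name `critic.h5` are hypothetical. Note the trade-off: the optimizer restarts from a fresh state instead of the saved one.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

def wasserstein_loss(y_true, y_pred):
    # Assumed definition of the custom loss.
    return tf.reduce_mean(y_true * y_pred)

x = np.random.rand(4, 8).astype("float32")
y = np.ones((4, 1), dtype="float32")

# Hypothetical tiny critic: one training step, then save with the custom loss.
path = os.path.join(tempfile.mkdtemp(), "critic.h5")
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss=wasserstein_loss)
model.train_on_batch(x, y)
model.save(path)

# compile=False restores architecture and weights only, so the saved loss
# name is never looked up and "Unknown loss function" is not raised.
restored = keras.models.load_model(path, compile=False)

# Re-attach the loss by hand; the optimizer state is NOT restored here.
restored.compile(optimizer="rmsprop", loss=wasserstein_loss)
loss = restored.train_on_batch(x, y)
```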


Solution

  • Solved it myself.

    Passing custom_objects={'wasserstein_loss': wasserstein_loss} to load_model() along with the model path solved my issue. The key must match the loss name stored in the file (the one printed in the error message), i.e.:

    Discriminator = load_model(model_path, custom_objects={'wasserstein_loss': wasserstein_loss})
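This also answers the second part of the question: a model restored this way can keep training with `train_on_batch()`, because in `tf.keras` saving to `.h5` with `model.save()` stores the architecture, weights, compile configuration and, by default, the optimizer state. The end-to-end sketch below assumes `tf.keras` (TensorFlow 2.x); the small critic, the random batch and the file name `discriminator.h5` are hypothetical stand-ins for the real discriminator.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

def wasserstein_loss(y_true, y_pred):
    # Assumed definition; what matters is that its name matches the one
    # stored in the .h5 file ('wasserstein_loss').
    return tf.reduce_mean(y_true * y_pred)

# Hypothetical stand-in for the discriminator (critic) and a dummy batch.
critic = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1),
])
critic.compile(optimizer="rmsprop", loss=wasserstein_loss)
x = np.random.rand(4, 8).astype("float32")
y = -np.ones((4, 1), dtype="float32")

# End of a session: a few train_on_batch() steps, then save the model.
for _ in range(3):
    critic.train_on_batch(x, y)
path = os.path.join(tempfile.mkdtemp(), "discriminator.h5")
critic.save(path)

# Next session: resolve the custom loss by its stored name, then keep training.
critic = keras.models.load_model(path, custom_objects={"wasserstein_loss": wasserstein_loss})
loss = critic.train_on_batch(x, y)
print("resumed, loss =", loss)
```

The same applies to every custom loss in the three models: if the generator-on-discriminator model was also compiled with the perceptual loss, that function must be added to `custom_objects` under the name it was saved with, e.g. `{'wasserstein_loss': wasserstein_loss, 'perceptual_loss': perceptual_loss}` (where `perceptual_loss` is a hypothetical name).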