Tags: keras, databricks, azure-databricks

Unable to save a Keras model in Databricks


I am saving a Keras model in Databricks with

model.save('model.h5')

but the model is not saved.

I have also tried saving to /tmp/model.h5, as mentioned here, but that does not work either.

The cell that saves the model runs without errors, but when I try to load the model, I am told that no model.h5 file exists.

When I copy the file to DBFS:

dbfs_model_path = 'dbfs:/FileStore/models/model.h5'
dbutils.fs.cp('file:/tmp/model.h5', dbfs_model_path)

or try to load the model:

tf.keras.models.load_model("file:/tmp/model.h5")

I get the error message java.io.FileNotFoundException: File file:/tmp/model.h5 does not exist


Solution

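  • The problem is that Keras is designed to work only with local file paths, so it doesn't understand URIs such as dbfs:/ or file:/. You need to use plain local paths for the save and load operations, and then copy the files to or from DBFS with dbutils.fs. (Unfortunately, writing directly through the /dbfs mount doesn't play well with Keras either, most likely because HDF5 files are written with random-access writes, which the /dbfs mount does not support.)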

    The following code works. Note that the dbfs:/ and file:/ prefixes appear only in the calls to dbutils.fs commands; the Keras calls use plain local file names.

    • Create the model and save it locally as /tmp/model-full.h5:
    from tensorflow.keras.applications import InceptionV3
    model = InceptionV3(weights="imagenet")
    model.save('/tmp/model-full.h5')
    
    • Copy the file to DBFS as dbfs:/tmp/model-full.h5 and check that it is there:
    dbutils.fs.cp("file:/tmp/model-full.h5", "dbfs:/tmp/model-full.h5")
    display(dbutils.fs.ls("dbfs:/tmp/model-full.h5"))
    
    • Copy the file from DBFS to the local path /tmp/model-full2.h5 and load it:
    dbutils.fs.cp("dbfs:/tmp/model-full.h5", "file:/tmp/model-full2.h5")
    from tensorflow import keras
    model2 = keras.models.load_model("/tmp/model-full2.h5")
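
The three steps above can be wrapped into small helpers, which may make the "local path for Keras, URI for dbutils" rule harder to get wrong. This is only a minimal sketch: the names to_local_path, save_via_local and load_via_local are hypothetical (they are not part of Keras or Databricks), and the copy function is passed in as the cp argument on the assumption that you will supply dbutils.fs.cp when running in a Databricks notebook.

```python
import os
import tempfile


def to_local_path(uri):
    """Turn a file:/ URI into the plain local path that Keras expects."""
    if uri.startswith("dbfs:"):
        # Keras cannot open DBFS URIs; the file has to be copied locally first.
        raise ValueError("copy the file from DBFS to a local path first")
    if uri.startswith("file:"):
        # "file:/tmp/m.h5" and "file:///tmp/m.h5" both become "/tmp/m.h5"
        return "/" + uri[len("file:"):].lstrip("/")
    return uri


def save_via_local(model, dbfs_uri, cp, local_dir=None):
    """Save `model` to a local file, then copy that file to `dbfs_uri`.

    `cp` is any callable taking (source_uri, target_uri);
    on Databricks this is assumed to be dbutils.fs.cp.
    """
    local_dir = local_dir or tempfile.mkdtemp()
    local_path = os.path.join(local_dir, os.path.basename(dbfs_uri))
    model.save(local_path)              # Keras call: plain local path
    cp("file:" + local_path, dbfs_uri)  # copy call: URIs with a scheme
    return local_path


def load_via_local(load_model, dbfs_uri, cp, local_dir=None):
    """Copy `dbfs_uri` to a local file, then load it with `load_model`."""
    local_dir = local_dir or tempfile.mkdtemp()
    local_path = os.path.join(local_dir, os.path.basename(dbfs_uri))
    cp(dbfs_uri, "file:" + local_path)  # copy call: URIs with a scheme
    return load_model(local_path)       # Keras call: plain local path
```

In a notebook this would be called as save_via_local(model, "dbfs:/tmp/model-full.h5", dbutils.fs.cp) and model2 = load_via_local(keras.models.load_model, "dbfs:/tmp/model-full.h5", dbutils.fs.cp). Passing the copy function in, rather than referring to dbutils directly, also lets the helpers be tried outside Databricks with a stand-in copy function.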