Tags: python, azure, databricks, datastore, mlflow

How can I retrieve the model.pkl from an experiment run in Databricks?


I want to retrieve the pickle file of my trained model, which I know is stored with the run inside my experiment in Databricks.

It seems that a model loaded with mlflow.pyfunc.load_model only exposes the predict method.

Is there a way to access the pickle directly?

I also tried calling pickle.load(path) with the artifact path from the run (example path: dbfs:/databricks/mlflow-tracking/20526156406/92f3ec23bf614c9d934dd0195/artifacts/model/model.pkl).


Solution

  • I recently found a solution. There are two possible approaches:

    1. Define a custom predict function when saving the model (see the Databricks documentation for details).

    Example given by Databricks:

    import mlflow.pyfunc

    class AddN(mlflow.pyfunc.PythonModel):

        def __init__(self, n):
            self.n = n

        def predict(self, context, model_input):
            # Add n to every column of the input DataFrame
            return model_input.apply(lambda column: column + self.n)

    # Construct and save the model
    model_path = "add_n_model"
    add5_model = AddN(n=5)
    mlflow.pyfunc.save_model(path=model_path, python_model=add5_model)

    # Load the model in `python_function` format
    loaded_model = mlflow.pyfunc.load_model(model_path)
    
    2. Download the model artifact from the run and load it with pickle:
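    import pickle

    from mlflow.tracking import MlflowClient

    client = MlflowClient()

    # Download the pickled model from the run to a local temporary path
    tmp_path = client.download_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path="model/model.pkl")

    # pickle.load needs a file object opened in binary mode, not a path string
    with open(tmp_path, "rb") as f:
        model = pickle.load(f)

    # To see which artifact paths exist in the run, list them
    client.list_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path="")
    client.list_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path="model")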
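
The second approach can be wrapped in a small helper. What follows is only a sketch under a few assumptions: the names `load_pickled_artifact` and `download` are hypothetical and not part of MLflow, and the default artifact path `model/model.pkl` is taken from the answer above. When no `download` callable is passed, the helper falls back to `mlflow.artifacts.download_artifacts`, which recent MLflow releases recommend over `MlflowClient.download_artifacts`. It also opens the file before unpickling, because `pickle.load` expects a binary file object rather than a path string, which is likely one reason the `pickle.load(path)` attempt in the question fails.

```python
import pickle
from typing import Any, Callable, Optional


def load_pickled_artifact(
    run_id: str,
    artifact_path: str = "model/model.pkl",
    download: Optional[Callable[[str, str], str]] = None,
) -> Any:
    """Download one artifact of a run and unpickle it.

    `download` is a hypothetical hook taking (run_id, artifact_path) and
    returning a local file path. It exists so the helper can be tried
    without a tracking server.
    """
    fetch = download
    if fetch is None:
        # Imported lazily so the module also loads where MLflow is absent.
        import mlflow.artifacts

        def fetch(rid: str, path: str) -> str:
            return mlflow.artifacts.download_artifacts(
                run_id=rid, artifact_path=path
            )

    local_path = fetch(run_id, artifact_path)

    # pickle.load needs a file object opened in binary mode.
    with open(local_path, "rb") as f:
        return pickle.load(f)
```

In a Databricks notebook this would be called as `model = load_pickled_artifact("0c7946c81fb64952bc8ccb3c7c66bca3")`. Unpickling only works if the libraries the model was trained with can be imported in the current environment, and pickle files should only be loaded from runs you trust.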