databricks, mlflow

Programmatically get the input schema of a model from the MLflow model registry


Is there a way to fetch the input schema (the features on which training was done) from the MLflow model registry? The input schema is captured via the 'signature' parameter when logging the trained model.


Solution

    I will describe two ways of doing this.

    The model signature can be retrieved from the metadata of the run that logged the model. Here is a picture showing where to find it in the UI:

    [Figure: the run view in the MLflow UI, showing the logged model's signature]

    Now, to extract this programmatically, note that logged-model metadata is stored as JSON in the run tag mlflow.log-model.history. Once we know the corresponding run ID (either kept at hand or looked up in the model registry), we can use this code snippet:

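    import json
    import mlflow
    from mlflow.client import MlflowClient
    
    # Connect to the tracking server that holds the run
    client = MlflowClient('http://0.0.0.0:5000')
    run_id = '467677aff0074955a4e75492085d52f9'
    run = client.get_run(run_id)
    # The tag value is a JSON-encoded list with one entry per model logged in this run
    log_model_meta = json.loads(run.data.tags['mlflow.log-model.history'])
    # Signature of the first logged model; its 'inputs' and 'outputs' are JSON strings
    log_model_meta[0]['signature']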
    

    which agrees with the figure :-)

    {'inputs': '[{"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1, 4]}}]',
     'outputs': '[{"type": "tensor", "tensor-spec": {"dtype": "int64", "shape": [-1]}}]'}
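Indexing with `[0]` picks whichever model was recorded first in the run. If a run logs more than one model, a small standard-library sketch like the one below can select the entry by its artifact path instead. It assumes each history entry carries an `artifact_path` key and that newer entries are appended at the end of the list; the helper name `find_logged_model`, the sample tag value and the artifact path `"model"` are illustrative only.

```python
import json
from typing import Any, Dict, Optional


def find_logged_model(history_tag: str, artifact_path: str) -> Optional[Dict[str, Any]]:
    """Return the last entry of a 'mlflow.log-model.history' tag value whose
    artifact_path matches, or None if no entry matches."""
    entries = json.loads(history_tag)  # the tag value is a JSON-encoded list
    matches = [e for e in entries if e.get("artifact_path") == artifact_path]
    # Assumption: entries are appended in logging order, so the last match is the latest
    return matches[-1] if matches else None


# Illustrative tag value shaped like the output above (hypothetical artifact path "model")
sample_tag = json.dumps([
    {
        "run_id": "467677aff0074955a4e75492085d52f9",
        "artifact_path": "model",
        "signature": {
            "inputs": '[{"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1, 4]}}]',
            "outputs": '[{"type": "tensor", "tensor-spec": {"dtype": "int64", "shape": [-1]}}]',
        },
    }
])

entry = find_logged_model(sample_tag, "model")
print(entry["signature"]["inputs"])
```

Against a live server this would be called as `find_logged_model(run.data.tags['mlflow.log-model.history'], 'model')`, with `'model'` replaced by the artifact path used when the model was logged.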
    

    Another way is to query the model registry. The schema/signature appears in the model version view, as shown below:

    [Figure: the model version view in the MLflow UI, showing the schema/signature]

    The same data can be obtained with the function mlflow.models.get_model_info, as in this snippet (which reuses the client from above):

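    # Resolve the artifact location of version 10 of the registered model 'toy-model'
    model_uri = client.get_model_version_download_uri('toy-model', '10')
    model_info = mlflow.models.get_model_info(model_uri)
    # Public accessor for the signature dict (avoids the private _signature_dict attribute)
    model_info.signature_dict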
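With either approach, the signature's `inputs` and `outputs` come back as JSON strings rather than Python objects. To get at the actual feature list, a minimal standard-library sketch could decode them as below. It assumes the two spec layouts seen in MLflow's serialized signatures: tensor specs (as in the output above) and column specs, where each entry carries a `name` and a `type`. The helper name `input_features` and the column-based sample are hypothetical. (MLflow can also rebuild a `ModelSignature` object from such a dict via `mlflow.models.ModelSignature.from_dict`.)

```python
import json
from typing import Any, Dict, List


def input_features(signature: Dict[str, str]) -> List[Dict[str, Any]]:
    """Decode the JSON string under signature['inputs'] into one dict per input."""
    features = []
    for spec in json.loads(signature["inputs"]):
        if spec.get("type") == "tensor":
            # Tensor-based input: dtype and shape live under 'tensor-spec'; name is optional
            tensor_spec = spec["tensor-spec"]
            features.append({"name": spec.get("name"),
                             "dtype": tensor_spec["dtype"],
                             "shape": tensor_spec["shape"]})
        else:
            # Column-based input: 'type' holds the column type, e.g. 'double' or 'string'
            features.append({"name": spec.get("name"), "dtype": spec["type"]})
    return features


# The tensor signature shown in the output above
tensor_signature = {
    "inputs": '[{"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1, 4]}}]',
    "outputs": '[{"type": "tensor", "tensor-spec": {"dtype": "int64", "shape": [-1]}}]',
}
# Hypothetical column-based signature, e.g. from a model trained on a pandas DataFrame
column_signature = {
    "inputs": '[{"name": "sepal length (cm)", "type": "double"}, '
              '{"name": "petal width (cm)", "type": "double"}]',
    "outputs": '[{"type": "long"}]',
}

print(input_features(tensor_signature))
print([f["name"] for f in input_features(column_signature)])
```

For a tensor signature like the toy model's there are no feature names to recover, only the dtype and shape; feature names are available only when the model was logged with a column-based signature.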