Tags: python, azure-machine-learning-service, tagging, mlflow

Setting tags during model logging in MLflow


I am logging the model using

mlflow.sklearn.log_model(model, "my-model")

and I want to set tags on the model while logging it. As far as I can tell, this method does not accept tags. There is a mlflow.set_tags() method, but it tags the run, not the model.

Does anyone know how to tag the model during logging?

Thank you!


Solution

  • When you use mlflow.sklearn.log_model, you are working with experiment tracking, which is run-focused, so only experiments and runs can be described and tagged.

    If you want to set tags on models, you need to work with the model registry.

    The solution I would recommend is to register the model while logging it by passing registered_model_name (there are more fine-grained ways, too), and then use the MlflowClient API to set custom properties (such as tags) on the registered model version.

    Here is a working example:

    import mlflow
    from mlflow.client import MlflowClient
    
    mlflow.set_tracking_uri('http://0.0.0.0:5000')
    
    experiment_name = 'test_mlflow'
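    # reuse the experiment if it already exists, otherwise create it
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = mlflow.create_experiment(experiment_name)
    else:
        experiment_id = experiment.experiment_id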
    
    from sklearn.linear_model import LogisticRegression
    from sklearn.datasets import load_iris
    from sklearn.metrics import accuracy_score
    
    with mlflow.start_run(experiment_id = experiment_id):
        # log performance and register the model
        X, y = load_iris(return_X_y=True)
        params = {"C": 0.1, "random_state": 42}
        mlflow.log_params(params)
        lr = LogisticRegression(**params).fit(X, y)
        y_pred = lr.predict(X)
        mlflow.log_metric("accuracy", accuracy_score(y, y_pred))
        mlflow.sklearn.log_model(lr, 
            artifact_path="models", 
            registered_model_name='test-model'
        )
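        # set extra tags on the model version that was just registered
        client = MlflowClient(mlflow.get_tracking_uri())
        # get_latest_versions returns the latest version per stage; for a
        # model with no stages in use, [0] is the newly registered version
        model_info = client.get_latest_versions('test-model')[0]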
        client.set_model_version_tag(
            name='test-model',
            version=model_info.version,
            key='task',
            value='classification'
        )
    

    Here is the illustration:

    [Image: tagged and registered model in MLflow]

    See also the excellent MlflowClient documentation.
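
As a follow-up to the example, here is a minimal sketch of two small helpers that might make the tagging step more robust. The names newest_version and tag_model are hypothetical (they are not part of MLflow); the sketch only assumes that the client passed in exposes search_model_versions, set_model_version_tag and set_registered_model_tag, as MlflowClient does. The first helper picks the highest version number explicitly instead of relying on the order of get_latest_versions, and the second sets several tags either on one version or on the registered model as a whole.

```python
from typing import Mapping, Optional


def newest_version(client, name: str) -> str:
    """Return the highest version number registered under `name`.

    `client` is assumed to be an MlflowClient (or an object with the
    same search_model_versions method).
    """
    versions = client.search_model_versions(f"name='{name}'")
    if not versions:
        raise ValueError(f"no versions registered for model {name!r}")
    # ModelVersion.version is a string, so compare numerically.
    return max(versions, key=lambda mv: int(mv.version)).version


def tag_model(
    client,
    name: str,
    tags: Mapping[str, object],
    version: Optional[str] = None,
) -> None:
    """Set every tag in `tags` on one model version.

    If `version` is None, the tags are set on the registered model
    itself rather than on a specific version.
    """
    for key, value in tags.items():
        if version is None:
            # Tag the registered model as a whole.
            client.set_registered_model_tag(name, key, str(value))
        else:
            # Tag a single version of the registered model.
            client.set_model_version_tag(
                name=name, version=version, key=key, value=str(value)
            )
```

With the client from the example, calling tag_model(client, 'test-model', {'task': 'classification'}, version=newest_version(client, 'test-model')) should tag the version that was just registered, while leaving out version tags the registered model 'test-model' itself.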