Tags: python, matplotlib, yellowbrick

Yellowbrick LearningCurve: change the legend label


I want to plot training and validation learning curves with Yellowbrick's LearningCurve visualizer. I am not using cross-validation to build the curve, though, but a hold-out validation set, so the legend entry, which is fixed to "Cross validation score", is misleading.

Is there a way to replace it with another string?


Solution

  • The visualizer exposes its matplotlib Axes object through the ax property, so you can change the label of the validation curve (the second line on the axes) before calling show():

    viz.ax.get_lines()[1].set_label('My custom label')
    

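    Example (using ValidationCurve; LearningCurve draws its training and cross-validation curves in the same order, so the same index applies):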

    import numpy as np
    
    from yellowbrick.datasets import load_energy
    from yellowbrick.model_selection import ValidationCurve
    
    from sklearn.tree import DecisionTreeRegressor
    
    # Load a regression dataset
    X, y = load_energy()
    
    viz = ValidationCurve(
        DecisionTreeRegressor(), param_name="max_depth",
        param_range=np.arange(1, 11), cv=10, scoring="r2"
    )
    
    # Fit and show the visualizer
    viz.fit(X, y)
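    # Lines are [training curve, validation curve]; relabel before show(),
    # which builds the legend when it finalizes the plot
    viz.ax.get_lines()[1].set_label('My custom label')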
    viz.show()
    

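
Indexing with get_lines()[1] depends on the order in which the visualizer draws its curves. A slightly more defensive variant looks the line up by its current label instead. rename_line_label below is a hypothetical helper, not part of Yellowbrick or matplotlib, and the demo uses two plain matplotlib lines as a stand-in for viz.ax. With a fitted visualizer you would call rename_line_label(viz.ax, "Cross Validation Score", "Validation score") before show(), assuming that is the exact default label; print [l.get_label() for l in viz.ax.get_lines()] to check.

```python
import matplotlib.pyplot as plt
from matplotlib.axes import Axes


def rename_line_label(ax: Axes, old_label: str, new_label: str) -> None:
    """Hypothetical helper: relabel every line on ax whose label is old_label."""
    matched = False
    for line in ax.get_lines():
        if line.get_label() == old_label:
            line.set_label(new_label)
            matched = True
    if not matched:
        raise ValueError(f"no line labelled {old_label!r} on these axes")
    # If a legend was already drawn, rebuild it (matplotlib defaults) so it
    # picks up the new text; otherwise the next legend() call will use it
    if ax.get_legend() is not None:
        ax.legend()


# Stand-in for viz.ax: two curves labelled like the visualizer's defaults
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [0.90, 0.92, 0.95], "o-", label="Training Score")
ax.plot([1, 2, 3], [0.70, 0.78, 0.80], "o-", label="Cross Validation Score")

rename_line_label(ax, "Cross Validation Score", "Validation Score")
ax.legend()
```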