Tags: python, matplotlib, legend, subplot, exploratory-data-analysis

Legend for vertical lines in subplots() with matplotlib.pyplot (Python)


I am trying to do EDA on a Kaggle dataset (link).

I made a figure with 3 subplots and drew 3 vertical lines on each, marking the mean, median and mode. Is there any way to show these 3 lines in a legend?

This is my code

import matplotlib.pyplot as plt
import seaborn as sns

def plott(data):
    fig, axes = plt.subplots(3, sharex=True, figsize=(15, 15),gridspec_kw={"height_ratios": (1, 0.2, 0.6)})
    fig.suptitle('Spread of Data for ' + data.name, fontsize=20, fontweight='bold')
    
    sns.histplot(data, kde=True, binwidth=1, ax=axes[0])
    sns.boxplot(x=data, orient='h', ax=axes[1])
    sns.violinplot(x=data, ax=axes[2])
    
    axes[0].set_xlabel('')
    axes[1].set_xlabel('')
    axes[2].set_xlabel('')

    axes[0].axvline(data.mean(), color='r', linewidth=2, linestyle='solid')
    axes[0].axvline(data.median(), color='r', linewidth=2, linestyle='dashed')
    axes[0].axvline(data.mode()[0], color='r', linewidth=2, linestyle='dotted')
    axes[1].axvline(data.mean(), color='r', linewidth=2, linestyle='solid')
    axes[1].axvline(data.median(), color='r', linewidth=2, linestyle='dashed')
    axes[1].axvline(data.mode()[0], color='r', linewidth=2, linestyle='dotted')
    axes[2].axvline(data.mean(), color='r', linewidth=2, linestyle='solid')
    axes[2].axvline(data.median(), color='r', linewidth=2, linestyle='dashed')
    axes[2].axvline(data.mode()[0], color='r', linewidth=2, linestyle='dotted')
    
    axes[0].tick_params(axis='both', which='both', labelsize=10, labelbottom=True)
    axes[1].tick_params(axis='both', which='both', labelsize=10, labelbottom=True)
    axes[2].tick_params(axis='both', which='both', labelsize=10, labelbottom=True)
    
    


plott(df['Age'])

The resulting plot shows the three lines on every subplot, but no legend.

Is there a way to add a legend for the 3 vertical lines, with each line style denoting its value?

Also, how can I add more tick values to the x-axis of all three graphs, for example at intervals of 2 or 5 years?

Thanks


Solution

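  • Pass a label argument to each axvline call, then call legend() after plotting. plt.legend() acts on the current axes; with subplots, call the method on the axes you want, e.g. axes[0].legend().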

    Example:

    import matplotlib.pyplot as plt
    
    plt.plot([1,2,3],[1,2,3],label="Test")
    plt.axvline(x=0.22058956, label="Test2", color="red")
    plt.legend()
    

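Applied to the three-subplot figure from the question, the same idea might look like the sketch below. It is only an illustration under some assumptions: the helper names add_stat_lines and demo are made up for this example, plain ax.hist stands in for the seaborn plots, and the statistics module is used so the snippet runs without pandas (with a Series you would keep data.mean(), data.median() and data.mode()[0]). It also covers the tick-spacing part of the question with matplotlib.ticker.MultipleLocator, which places a major tick at every multiple of the given step.

```python
import statistics

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator


def add_stat_lines(ax, values):
    """Draw labelled mean/median/mode lines on one Axes and show its legend."""
    stats = {
        "mean": (statistics.mean(values), "solid"),
        "median": (statistics.median(values), "dashed"),
        "mode": (statistics.mode(values), "dotted"),
    }
    for name, (value, style) in stats.items():
        # legend() picks up the label; the value is included for readability.
        ax.axvline(value, color="r", linewidth=2, linestyle=style,
                   label=f"{name} = {value:.1f}")
    ax.legend(loc="upper right")


def demo(values, tick_step=5):
    """Hypothetical stand-in for plott(): three stacked plots sharing the x-axis."""
    fig, axes = plt.subplots(3, sharex=True, figsize=(15, 15))
    for ax in axes:
        ax.hist(values, bins=10)  # placeholder for the seaborn plots
        add_stat_lines(ax, values)
        # One major tick every `tick_step` units on the x-axis.
        ax.xaxis.set_major_locator(MultipleLocator(tick_step))
        ax.tick_params(axis="x", labelbottom=True)
    return fig, axes
```

Calling demo(df['Age'], tick_step=2) followed by plt.show() should give a legend on each subplot and ticks 2 years apart. If one legend is enough, call add_stat_lines on axes[0] only and draw unlabelled lines on the other two.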