Search code examples
Tags: python, seaborn, density-plot

How to draw a vertical line at the mode of a seaborn distplot


I just learned how to draw a density plot with the seaborn Python module:

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.pyplot import (plot, savefig, xlim, figure,
                              ylim, legend, boxplot, setp,
                              axes, xlabel, ylabel, xticks,
                              axvline)
import seaborn as sns

# Sample values (14 floats) whose density make_density plots below.
layer1_G1_G2 = [
    -0.05567627772688866,
    -0.06829605251550674,
    -0.0721447765827179,
    -0.05942181497812271,
    -0.061410266906023026,
    -0.062010858207941055,
    -0.05238522216677666,
    -0.057129692286252975,
    -0.06323938071727753,
    -0.07018601894378662,
    -0.05972284823656082,
    -0.06124034896492958,
    -0.06971242278814316,
    -0.06730005890130997,
]

def make_density(layer_list, color, layer_num, *, show_mode=False):
    """Plot a kernel density estimate of ``layer_list`` on the current axes.

    Parameters
    ----------
    layer_list : sequence of float
        Sample values whose density is plotted.
    color : str
        Matplotlib color used for the density curve (and the mode line).
    layer_num : str or int
        Layer identifier shown in the title; formatted with ``str()``.
    show_mode : bool, optional
        If True, draw a dotted vertical line from 0 up to the peak of the
        plotted KDE curve, i.e. at its mode. The position is only as precise
        as the grid on which seaborn evaluates the curve.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the curve was drawn on.
    """
    # Plot formatting (fixed axis limits, applied before drawing as before).
    plt.title(f'Density Plot of Median Stn. MC-Losses at Layer {layer_num}')
    plt.xlabel('MC-Loss')
    plt.ylabel('Density')
    plt.xlim(-0.2, 0.05)
    plt.ylim(0, 85)

    ax = plt.gca()
    # Remember how many lines already exist so the new curve can be found
    # even when several densities are overlaid on the same axes.
    n_lines_before = len(ax.lines)

    # Draw the density curve. ``distplot(..., hist=False, kde=True)`` is
    # deprecated since seaborn 0.11; ``kdeplot`` draws the same KDE line.
    sns.kdeplot(x=layer_list, color=color, linewidth=2, ax=ax)

    # Only mark the mode if a new curve was actually added; otherwise
    # ax.lines[-1] would be an older curve (or ax.lines would be empty).
    if show_mode and len(ax.lines) > n_lines_before:
        curve = ax.lines[-1]
        xs = np.asarray(curve.get_xdata())
        ys = np.asarray(curve.get_ydata())
        peak = int(np.argmax(ys))
        ax.vlines(xs[peak], 0, ys[peak], color=color, linestyles=':')

    return ax

# Plot the density of the layer-1 values. plt.show() is needed to display
# the figure when this runs as a script instead of an interactive session.
make_density(layer1_G1_G2, 'green', '1')
plt.show()

Image generated from the code above:

How can I draw a vertical line at the mode of this density curve on this distplot?

Thank you,


Solution

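  • You can extract the x and y values of the generated KDE curve and take the mode as the x-value at which the curve's y-value (the density) is highest. The result is only as precise as the grid of points on which seaborn evaluates the curve.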

    from matplotlib import pyplot as plt
    import seaborn as sns
    
    # The 14 sample values from the question.
    layer1_G1_G2 = [
        -0.05567627772688866, -0.06829605251550674, -0.0721447765827179,
        -0.05942181497812271, -0.061410266906023026, -0.062010858207941055,
        -0.05238522216677666, -0.057129692286252975, -0.06323938071727753,
        -0.07018601894378662, -0.05972284823656082, -0.06124034896492958,
        -0.06971242278814316, -0.06730005890130997,
    ]
    
    def make_density(layer_list, color, layer_num):
        """Plot the KDE of ``layer_list`` and mark its mode with a vertical line.

        The mode is taken as the x-position where the plotted KDE curve is
        highest, so it is only as precise as the grid seaborn evaluates the
        curve on.

        Args:
            layer_list: Sample values whose density is plotted.
            color: Matplotlib color for the density curve.
            layer_num: Layer identifier shown in the title (formatted with str()).

        Returns:
            The matplotlib Axes the curve was drawn on.
        """
        ax = plt.gca()
        # Count existing lines so the curve drawn by this call can be found
        # even when several densities are overlaid on the same axes.
        n_lines_before = len(ax.lines)

        # Draw the density curve. ``distplot(hist=False, kde=True)`` is
        # deprecated since seaborn 0.11; ``kdeplot`` draws the same KDE line.
        sns.kdeplot(x=layer_list, color=color, linewidth=2, ax=ax)

        # Use the newest line: ax.lines[0] would be an earlier curve on a
        # second call. If no new curve was drawn, skip the mode line instead
        # of indexing an old or empty ax.lines.
        if len(ax.lines) > n_lines_before:
            curve = ax.lines[-1]
            x = curve.get_xdata()
            y = curve.get_ydata()
            mode_idx = y.argmax()
            ax.vlines(x[mode_idx], 0, y[mode_idx], color='crimson', ls=':')

        # Plot formatting
        ax.set_title(f'Density Plot of Median Stn. MC-Losses at Layer {layer_num}')
        ax.set_xlabel('MC-Loss')
        ax.set_ylabel('Density')
        ax.autoscale(axis='x', tight=True)
        ax.set_ylim(bottom=0)
        return ax
    
    # Draw the layer-1 density with its mode marked, then display the figure.
    make_density(layer1_G1_G2, color='green', layer_num='1')
    plt.show()
    

    example plot