Search code examples
python matplotlib seaborn bar-chart legend

Add a legend to a seaborn barplot


import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

array = np.array([[1,5,9],[3,5,7]])
# Label the two rows so each series can be selected by name with df.loc.
df = pd.DataFrame(data=array, index=['Positive', 'Negative'])

f, ax = plt.subplots(figsize=(8, 6))

current_palette = sns.color_palette('colorblind')
# Draw both rows at x = 0, 1, 2 and keep whatever each barplot call returns.
ax_pos = sns.barplot(x = np.arange(0,3,1), y = df.loc['Positive'].to_numpy(), color = current_palette[2], alpha = 0.66)
ax_neg = sns.barplot(x = np.arange(0,3,1), y = df.loc['Negative'].to_numpy(), color = current_palette[4], alpha = 0.66)

plt.xticks(np.arange(0,3,1), fontsize = 20)
plt.yticks(np.arange(0,10,1), fontsize = 20)
# Fails here: ax_pos and ax_neg are Axes objects, so indexing them with [0] raises TypeError.
plt.legend((ax_pos[0], ax_neg[0]), ('Positive', 'Negative'))

plt.tight_layout()

This produces the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[32], line 15
     12 plt.xticks(np.arange(0,3,1), fontsize = 20)
     13 plt.yticks(np.arange(0,10,1), fontsize = 20)
---> 15 plt.legend((ax_pos[0], ax_neg[0]), ('Positive', 'Negative'))
     17 plt.tight_layout()

TypeError: 'Axes' object is not subscriptable

enter image description here

I would like to know why calling legend like this (plt.legend((ax_pos[0], ax_neg[0]), ...)) is not possible with seaborn, whereas it is with matplotlib.

In the end, I just want the legend in the upper left corner.


Solution

  • I figured out that barplot accepts a "label" argument:

    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    
    # Two rows of three values; the index labels let each row be selected by name.
    array = np.array([[1, 5, 9], [3, 5, 7]])
    df = pd.DataFrame(data=array, index=['Positive', 'Negative'])

    f, ax = plt.subplots(figsize=(8, 6))

    current_palette = sns.color_palette('colorblind')
    x_positions = np.arange(0, 3, 1)

    # Draw both series on the same Axes. The `label` given to each call is what
    # the legend picks up, so no bar handles have to be collected by hand.
    for row_label, bar_color in (('Positive', current_palette[2]),
                                 ('Negative', current_palette[4])):
        sns.barplot(x=x_positions, y=df.loc[row_label].to_numpy(),
                    color=bar_color, alpha=0.66, label=row_label, ax=ax)

    plt.xticks(x_positions, fontsize=20)
    plt.yticks(np.arange(0, 10, 1), fontsize=20)

    # Pin the legend to the upper-left corner (the placement asked for) rather
    # than relying on the default location.
    ax.legend(loc='upper left', frameon=False)

    plt.tight_layout()
    

    enter image description here