python matplotlib graph label

How to get a distinct label for each group of scatter points in a graph using matplotlib in Python?


I am using a for loop to assign labels and colors to points. But I am not getting one unique label per color; instead, I get a long run of labels at the top of my graph, because my Excel data has about 500 rows. Here is my code:

col, lab = [], []
for val in ghi:
    if val < 2:
        col.append('navy')
        lab.append(' <2 ')
    elif (val >= 2) and (val < 4):
        col.append('lightblue')
        lab.append(' 2~4 ')
    elif (val >= 4) and (val < 6):
        col.append('orange')
        lab.append(' 4~6 ')
    else:
        col.append('brown')
        lab.append(' >6 ')
plt.scatter(x, y, marker="D", s=10, color=col, label=lab)
sns.lineplot(x=x, y=y, color="red")
plt.legend()

Solution

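  • You are seeing the long run of labels because of label=lab: lab holds one entry per row, and the whole list is passed as the label of a single scatter call. Instead, pass explicit handles and labels to plt.legend(), so that the legend has just four entries, one for each branch of the if/else chain. The code below also sorts the data by the ghi field; the legend itself is built from the custom handles, so each handle's color must match the color assigned to that range.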

    As no data was provided, the first few lines create random data. Note that I put the data into a DataFrame and sorted it by the ghi field. The legend is then given handles and labels that you can customize as needed. Hope this is what you are trying to achieve.

    I have kept the same names and as much of your code as possible, so that it is easier for you to follow.

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    from matplotlib.lines import Line2D

    # Random sample data, since none was provided
    ghi = np.random.randint(11, size=20)
    x = np.random.rand(20)
    y = np.random.rand(20)

    # One color per row, chosen by the range that ghi falls into
    col = []
    for val in ghi:
        if val < 2:
            col.append('navy')
        elif 2 <= val < 4:
            col.append('lightblue')
        elif 4 <= val < 6:
            col.append('orange')
        else:
            col.append('brown')

    df = pd.DataFrame({'x': x, 'y': y, 'ghi': ghi, 'col': col})
    df.sort_values(by=['ghi'], inplace=True)

    # No label= here: the legend is built by hand below
    plt.scatter(df.x, df.y, marker="D", s=50, color=df.col)
    # Recent seaborn versions require x and y as keyword arguments
    sns.lineplot(x=x, y=y, color="red")

    # One proxy handle per color, in the same order as the labels
    myHandles = [Line2D([], [], marker='D', color='navy', linestyle='None'),
                 Line2D([], [], marker='D', color='lightblue', linestyle='None'),
                 Line2D([], [], marker='D', color='orange', linestyle='None'),
                 Line2D([], [], marker='D', color='brown', linestyle='None')]

    plt.legend(handles=myHandles, labels=[' <2 ', ' 2~4 ', ' 4~6 ', ' >6 '])
    plt.show()
    

    Output plot: [image not included]
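
A different route, which is not part of the answer above, is to call scatter once per value range and give each call a single string label; matplotlib then builds one legend entry per call, with no hand-made handles. The following is only a minimal sketch under that assumption. The names BINS, bin_index and scatter_by_bin are made up for illustration, and the bin bounds simply mirror the if/elif chain in the question.

```python
import matplotlib.pyplot as plt
import numpy as np

# (exclusive upper bound, color, label) for each range; hypothetical table
# that mirrors the if/elif chain. As in that chain, the last range also
# includes the value 6 itself, even though its label reads ">6".
BINS = [
    (2, "navy", " <2 "),
    (4, "lightblue", " 2~4 "),
    (6, "orange", " 4~6 "),
    (np.inf, "brown", " >6 "),
]


def bin_index(ghi):
    """Return the range number (0 to 3) for each value in ghi."""
    edges = [upper for upper, _, _ in BINS[:-1]]  # [2, 4, 6]
    # With the default right=False: val < 2 -> 0, 2 <= val < 4 -> 1, and so on.
    return np.digitize(ghi, edges)


def scatter_by_bin(ax, x, y, ghi):
    """Draw one scatter call per non-empty range, each with its own label."""
    x, y = np.asarray(x), np.asarray(y)
    idx = bin_index(ghi)
    for i, (_, color, label) in enumerate(BINS):
        mask = idx == i
        if mask.any():  # skip empty ranges so they get no legend entry
            ax.scatter(x[mask], y[mask], marker="D", s=50,
                       color=color, label=label)
    ax.legend()
    return ax
```

With the x, y and ghi arrays from the question, this could be used as scatter_by_bin(plt.gca(), x, y, ghi) followed by plt.show(). Because colors and labels live in one table, they cannot drift out of step, and ranges with no points do not appear in the legend.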