Tags: matplotlib, plot, legend, scatter-plot

Add a legend to a figure


Here is my code:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
import numpy as np
 
fig, subs = plt.subplots(4,3) #setting the shape of the figure in one line as opposed to creating 12 variables 
 
iris = load_iris() ##code as per the example 
data = np.array(iris['data'])
targets = np.array(iris['target'])
 
cd = {0:'r',1:'b',2:"g"}
cols = np.array([cd[target] for target in targets])

 
# Row 1 
 
subs[0][0].scatter(data[:,0], data[:,1], c=cols)
subs[0][1].scatter(data[:,0], data[:,2], c=cols)
subs[0][2].scatter(data[:,0], data[:,3], c=cols)
 
# Row 2 
 
subs[1][0].scatter(data[:,1], data[:,0], c=cols)
subs[1][1].scatter(data[:,1], data[:,2], c=cols)
subs[1][2].scatter(data[:,1], data[:,3], c=cols)
 
# Row 3 
 
subs[2][0].scatter(data[:,2], data[:,0], c=cols)
subs[2][1].scatter(data[:,2], data[:,1], c=cols)
subs[2][2].scatter(data[:,2], data[:,3], c=cols)
 
# Row 4 
 
subs[3][0].scatter(data[:,3], data[:,0], c=cols)
subs[3][1].scatter(data[:,3], data[:,1], c=cols)
subs[3][2].scatter(data[:,3], data[:,2], c=cols)
 
plt.show()

Output

I would like to add a legend indicating that the red dots represent 'setosa', the green dots 'versicolor' and the blue dots 'virginica'. The legend should sit at the bottom center of the figure above. How can I do that?

I think I have to use fig.legend, but I am not sure how.


Solution

  • In one of the subplots, you can loop over the targets and plot each class separately with its own label, then place that subplot's legend outside its axes. Here is what I obtained with your code:

    [Image: multiple scatter plots legend]

    Here is the code:

    
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris
    import numpy as np
    
    fig, subs = plt.subplots(4,3, constrained_layout=True) #setting the shape of the figure in one line as opposed to creating 12 variables
    
    iris = load_iris() ##code as per the example
    data = np.array(iris['data'])
    target_names = iris['target_names']
    targets = np.array(iris['target'])
    
    cd = {0:'r',1:'b',2:"g"}
    cols = np.array([cd[target] for target in targets])
    
    
    # Row 1
    
    subs[0][0].scatter(data[:,0], data[:,1], c=cols)
    subs[0][1].scatter(data[:,0], data[:,2], c=cols)
    subs[0][2].scatter(data[:,0], data[:,3], c=cols)
    
    # Row 2
    
    subs[1][0].scatter(data[:,1], data[:,0], c=cols)
    subs[1][1].scatter(data[:,1], data[:,2], c=cols)
    subs[1][2].scatter(data[:,1], data[:,3], c=cols)
    
    # Row 3
    
    subs[2][0].scatter(data[:,2], data[:,0], c=cols)
    subs[2][1].scatter(data[:,2], data[:,1], c=cols)
    subs[2][2].scatter(data[:,2], data[:,3], c=cols)
    
    # Row 4
    subs[3][0].scatter(data[:,3], data[:,0], c=cols)
    
    # loop for central subplot at last row
    for t, name in zip(np.unique(targets), target_names):
        subs[3][1].scatter(data[targets==t,3], data[targets==t,1], c=cd[t], label=name)
    subs[3][1].legend(bbox_to_anchor=(2, -.2), ncol=len(target_names))  # you can play with bbox_to_anchor for legend position
    
    subs[3][2].scatter(data[:,3], data[:,2], c=cols)
    
    
    plt.savefig('legend')
    

    EDIT: I have also found this post in the matplotlib documentation, which shows how to extract the legend entries directly from a scatter plot (without using a for loop). I tried it on the iris dataset but could not get it to work.
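
A possible reason the loop-free approach did not work: `legend_elements()` builds its handles from the numeric values mapped through a colormap, and with `c=cols` (an array of color strings) the scatter has no such values attached. The sketch below is one way it could be done, not a tested drop-in for the code above. It assumes the integer `targets` are passed as `c` together with a `ListedColormap` that reproduces the colors in `cd`, and it uses `fig.legend` to put a single legend at the bottom center as the question asks. The nested loop, the `figsize`, the `subplots_adjust` value and the output file name are my own choices.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.datasets import load_iris

iris = load_iris()
data = iris["data"]
targets = iris["target"]              # integer class ids 0, 1, 2
target_names = iris["target_names"]   # setosa, versicolor, virginica

# Same colors as cd = {0: 'r', 1: 'b', 2: 'g'} in the code above.
cmap = ListedColormap(["r", "b", "g"])

fig, subs = plt.subplots(4, 3, figsize=(9, 10))

# Row i uses feature i on the x axis and each remaining feature on the y axis,
# which reproduces the twelve panels of the original figure.
for row in range(4):
    others = [f for f in range(4) if f != row]
    for col, other in enumerate(others):
        sc = subs[row][col].scatter(data[:, row], data[:, other],
                                    c=targets, cmap=cmap)

# Every panel shares the same classes and colors, so the handles from the
# last scatter are enough. The default numeric labels are replaced by names.
handles, _ = sc.legend_elements()
fig.legend(handles, list(target_names), loc="lower center",
           ncol=len(target_names))

fig.subplots_adjust(bottom=0.1)  # leave room for the legend under the last row
fig.savefig("legend_elements.png")
```

Note that with `cd` as written, 'versicolor' (target 1) is drawn in blue and 'virginica' (target 2) in green. To get the colors described in the question, the list passed to `ListedColormap` could be changed to `["r", "g", "b"]`.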