
How to customize the color legend when using a for loop


I want to draw a 3D scatter plot in which the points are colored by group. Here is a data sample:

import pandas as pd

aa=pd.DataFrame({'a':[1,2,3,4,5],
                 'b':[2,3,4,5,6],
                 'c':[1,3,4,6,9],
                 'd':[0,0,1,2,3],
                 'e':['abc','sdf','ert','hgf','nhkm']})

Here, a, b and c are the x, y and z axes, e is the text label shown next to each point, and d is the group that determines each point's color.

Here is my code:

import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
zdirs = aa.loc[:,'e'].__array__()
xs = aa.loc[:,'a'].__array__()
ys = aa.loc[:,'b'].__array__()
zs = aa.loc[:,'c'].__array__()
colors = aa.loc[:,'d'].__array__()
colors1=np.where(colors==0,'grey',
                 np.where(colors==1,'yellow',
                          np.where(colors==2,'green',
                                   np.where(colors==3,'pink','red'))))
for i in range(len(zdirs)):  # plot each point and its label from column e
    ax.scatter(xs[i],ys[i],zs[i],color=colors1[i])
    ax.text(xs[i],ys[i],zs[i],  '%s' % (str(zdirs[i])), size=10, zorder=1, color='k')
ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('c')
plt.show()

But I do not know how to add a legend to the plot. I would like a legend that lists each value of d next to its color.


The colors and numbers should match, and the entries should be sorted by number.

Could anyone show me how to customize the color bar (legend) this way?


Solution

  • First of all, I've taken the liberty of simplifying your code a bit:

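    • I'd suggest creating a ListedColormap to map integer -> color, which allows you to pass the color column via c=aa['d'] (note it's c=, not color=!). Also pass vmin=0 and vmax=cmap.N - 1: otherwise matplotlib rescales d to its own min/max (0..3 here), and with five colors the value 2 would be drawn pink instead of green.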
    • you don't need __array__() here; the plotting functions accept the DataFrame columns directly, e.g. aa['a']
    • finally, you can add an empty scatter plot for each group value in d, drawn in that value's color and labelled with it; ax.legend() then renders these empty plots as the legend entries
    import pandas as pd
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)
    from matplotlib.colors import ListedColormap
    
    aa=pd.DataFrame({'a':[1,2,3,4,5],
                     'b':[2,3,4,5,6],
                     'c':[1,3,4,6,9],
                     'd':[0,0,1,2,3],
                     'e':['abc','sdf','ert','hgf','nhkm']})
    
    fig = plt.figure()
    # fig.gca(projection='3d') stopped working in matplotlib 3.6
    ax = fig.add_subplot(projection='3d')
    
    cmap = ListedColormap(['grey', 'yellow', 'green', 'pink', 'red'])
    # vmin/vmax pin the color scale so that the integer i is drawn as cmap(i)
    ax.scatter(aa['a'], aa['b'], aa['c'], c=aa['d'], cmap=cmap,
               vmin=0, vmax=cmap.N - 1)
    
    for x, y, z, label in zip(aa['a'], aa['b'], aa['c'], aa['e']):
        ax.text(x, y, z, label, size=10, zorder=1)
    
    # Create a legend through an *empty* scatter plot per group value in 'd'
    for i in sorted(aa['d'].unique()):
        ax.scatter([], [], color=cmap(i), label=str(i))
    ax.legend()
    
    ax.set_xlabel('a')
    ax.set_ylabel('b')
    ax.set_zlabel('c')
    
    plt.show()
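Since the title asks about the for loop specifically, another common pattern is to loop over the groups instead of the individual points. This is a sketch of that alternative, not part of the answer above: each ax.scatter call draws one whole group with one label, so ax.legend() needs no empty proxy plots, and because DataFrame.groupby sorts its keys by default the entries come out in numeric order. The group_colors dict is a hypothetical name; it reproduces the question's grey/yellow/green/pink mapping with red as the fallback, like the nested np.where.

```python
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)

aa = pd.DataFrame({'a': [1, 2, 3, 4, 5],
                   'b': [2, 3, 4, 5, 6],
                   'c': [1, 3, 4, 6, 9],
                   'd': [0, 0, 1, 2, 3],
                   'e': ['abc', 'sdf', 'ert', 'hgf', 'nhkm']})

# Hypothetical mapping: same colors as the question; anything else is red
group_colors = {0: 'grey', 1: 'yellow', 2: 'green', 3: 'pink'}

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# One scatter call per group; groupby sorts the keys, so the legend is ordered
for d, group in aa.groupby('d'):
    ax.scatter(group['a'], group['b'], group['c'],
               color=group_colors.get(d, 'red'), label=str(d))

# Text labels are drawn per point, independently of the grouping
for x, y, z, label in zip(aa['a'], aa['b'], aa['c'], aa['e']):
    ax.text(x, y, z, label, size=10, zorder=1, color='k')

ax.legend(title='d')
ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('c')
plt.show()
```

The trade-off is one scatter call per group instead of one for the whole frame, which is negligible for a handful of groups.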
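The question also mentions a color bar. If a discrete colorbar is wanted instead of a legend, one option (a sketch, not taken from the answer above) is to give the scatter a BoundaryNorm in which every integer gets its own bin, and then ask fig.colorbar for one tick per color. The bin edges -0.5, 0.5, ..., 4.5 are an assumption chosen to match the five colors in the ListedColormap.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)
from matplotlib.colors import BoundaryNorm, ListedColormap

aa = pd.DataFrame({'a': [1, 2, 3, 4, 5],
                   'b': [2, 3, 4, 5, 6],
                   'c': [1, 3, 4, 6, 9],
                   'd': [0, 0, 1, 2, 3],
                   'e': ['abc', 'sdf', 'ert', 'hgf', 'nhkm']})

cmap = ListedColormap(['grey', 'yellow', 'green', 'pink', 'red'])
# Bin edges -0.5, 0.5, ..., 4.5: the integer k falls into bin k and gets cmap(k)
bounds = np.arange(-0.5, cmap.N)
norm = BoundaryNorm(bounds, cmap.N)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
sc = ax.scatter(aa['a'], aa['b'], aa['c'], c=aa['d'], cmap=cmap, norm=norm)

for x, y, z, label in zip(aa['a'], aa['b'], aa['c'], aa['e']):
    ax.text(x, y, z, label, size=10, zorder=1)

# One tick per color, centred in its band of the colorbar
cbar = fig.colorbar(sc, ax=ax, ticks=list(range(cmap.N)))
cbar.set_label('d')

ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('c')
plt.show()
```

Note that the colorbar shows all five colors, including red, even though no point in this sample has d = 4; drop 'red' from the colormap if it should not appear.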