Search code examples
Tags: python, matplotlib, plot, scatter-plot

Python 3D scatter plot: legend for colors shows only the first color


I want to create a 3D scatter plot with legends for the sizes and the colors. However, the legend for the colors only shows the first color in the list.

import matplotlib.pyplot as plt
import matplotlib.colors

# Bubble chart of 5-D mixed data: x/y/z position, plus marker colour (hue)
# and marker size as two extra dimensions.

# Sample data, one entry per point.
xs = [1, 2, 3, 5, 4]
ys = [6, 7, 3, 5, 4]
zs = [1, 5, 3, 9, 4]
data_points = list(zip(xs, ys, zs))  # (x, y, z) triples; not used below

ss = [100, 200, 390, 500, 400]                       # per-point marker sizes (s=)
colors = ['red', 'red', 'blue', 'yellow', 'yellow']  # per-point literal colours (c=)

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
title = fig.suptitle('Wine Residual Sugar - Alcohol Content - Acidity - Total Sulfur Dioxide - Type', fontsize=14)

scatter = ax.scatter(xs, ys, zs, c=colors, s=ss, alpha=0.4)

for set_label, text in ((ax.set_xlabel, 'Residual Sugar'),
                        (ax.set_ylabel, 'Alcohol'),
                        (ax.set_zlabel, 'Fixed Acidity')):
    set_label(text)

# Colour legend -- this reproduces the reported problem (only the first
# colour is listed). NOTE(review): presumably legend_elements() returns no
# colour handles because c= holds literal colour names rather than values to
# colour-map, so `labels` ends up paired with the single scatter artist.
color_handles = scatter.legend_elements()[0]
legend1 = ax.legend(*color_handles,
                    labels=colors, title="Classes", loc="upper right",
                    bbox_to_anchor=(1.5, 1), prop={'size': 20})
# A later ax.legend() call replaces the current legend, so keep this one
# attached to the axes as an ordinary artist.
ax.add_artist(legend1)

# Size legend: a cross-section of the marker sizes used by the scatter.
size_handles, size_labels = scatter.legend_elements(prop="sizes", alpha=0.6)
legend2 = ax.legend(size_handles, size_labels, title="Sizes", loc="upper right",
                    bbox_to_anchor=(1.5, 0.5), prop={'size': 20})

(Screenshot of the resulting plot, in which the "Classes" legend shows only the first color.)


Solution

  • The issue likely arises because matplotlib receives only one series to plot (a single scatter call) and therefore assumes that one legend entry suffices. If I make separate scatter plots for the red, blue and yellow series, all three classes are displayed correctly in the legend, but that complicates creating the legend for the sizes.

    It's perhaps not the most elegant solution, but the class legend can be created manually from proxy artists, one marker per unique color:

    import matplotlib.pyplot as plt
    import matplotlib.colors
    from matplotlib.lines import Line2D

    # Visualizing 5-D mix data using bubble charts
    # leveraging the concepts of hue, size and depth
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    fig.suptitle('Wine Residual Sugar - Alcohol Content - Acidity - Total Sulfur Dioxide - Type', fontsize=14)

    xs = [1, 2, 3, 5, 4]
    ys = [6, 7, 3, 5, 4]
    zs = [1, 5, 3, 9, 4]

    ss = [100, 200, 390, 500, 400]                       # per-point marker sizes (s=)
    colors = ['red', 'red', 'blue', 'yellow', 'yellow']  # per-point literal colours (c=)

    scatter = ax.scatter(xs, ys, zs, alpha=0.4, c=colors, s=ss)

    ax.set_xlabel('Residual Sugar')
    ax.set_ylabel('Alcohol')
    ax.set_zlabel('Fixed Acidity')

    # Class legend. c= holds literal colour names, not values to colour-map,
    # so scatter.legend_elements() cannot build colour entries. Build one proxy
    # marker per unique colour instead; dict.fromkeys de-duplicates while
    # keeping first-seen order (red, blue, yellow).
    unique_colors = list(dict.fromkeys(colors))
    class_handles = [
        # linestyle='none': show only the marker, with no connecting line
        # drawn through the legend entry.
        Line2D([0], [0], marker='o', linestyle='none', color='w', label=color,
               markerfacecolor=color, markersize=15, alpha=0.4)
        for color in unique_colors
    ]

    # Pass the proxies via handles= only. Also passing a positional handles
    # list makes matplotlib warn that mixed positional/keyword input may be
    # discarded, and the positional list is silently dropped.
    legend1 = ax.legend(handles=class_handles, loc="upper right", title="Classes",
                        bbox_to_anchor=(1.5, 1), prop={'size': 20})
    # The next ax.legend() call replaces the axes' legend, so re-attach this
    # one as an ordinary artist to keep both legends visible.
    ax.add_artist(legend1)

    # Size legend: a cross-section of the marker sizes used by the scatter.
    handles, labels = scatter.legend_elements(prop="sizes", alpha=0.6)
    legend2 = ax.legend(handles, labels, loc="upper right", title="Sizes",
                        bbox_to_anchor=(1.5, 0.5), prop={'size': 20})

    plt.show()