Tags: python, matplotlib, plot, cluster-analysis, scatter-plot

Changing the marker style in a Matplotlib 2D scatter plot with a colorbar according to cluster data


I am performing clustering and trying to plot the result with Matplotlib's scatter function.

A dummy dataset:

import numpy as np

x = np.array([48.959, 49.758, 49.887, 50.593, 50.683])
y = np.array([122.310, 121.29, 120.525, 120.252, 119.509])
z = np.array([136.993, 133.128, 143.710, 129.088, 139.860])

I plot x against y and use z for the color axis with the following code:

import matplotlib.pyplot as plt
# cm1 is my colormap; its definition is not shown here
plt.scatter(x=x, y=y, c=z, label="CO2 Emissions Saved Cumulative", cmap=cm1)

Here is how the plot looks for the entire dataset:

I then performed k-means clustering on my dataset and found three clusters, for example:

[0 0 0 0 0 2 1 2 1 2 1 1 2 1 1 1 2 2 2 2 2]

I found the following way to plot them, distinguishing the clusters by marker style:

# x, y and cluster must be NumPy arrays for the boolean masks to work
fig, ax = plt.subplots()
ax.scatter(x[cluster == 0], y[cluster == 0], marker="*")
ax.scatter(x[cluster == 1], y[cluster == 1], marker="^")
ax.scatter(x[cluster == 2], y[cluster == 2], marker="s")
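A minimal sketch of why the colors change, assuming x, y and cluster are NumPy arrays (the labels [0, 1, 0, 2, 2] below are hypothetical, chosen only so that each of the five dummy points has a cluster). The expression cluster == k is a boolean mask that selects one cluster's points. Because no c= argument is passed, each scatter call is drawn in a single color taken from Matplotlib's default color cycle, so color ends up encoding the cluster instead of z, and no z values are attached to the points at all.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

x = np.array([48.959, 49.758, 49.887, 50.593, 50.683])
y = np.array([122.310, 121.29, 120.525, 120.252, 119.509])
cluster = np.array([0, 1, 0, 2, 2])  # hypothetical labels for the 5 dummy points

fig, ax = plt.subplots()
collections = []
for k, marker in zip([0, 1, 2], ["*", "^", "s"]):
    mask = cluster == k  # True where the point belongs to cluster k
    collections.append(ax.scatter(x[mask], y[mask], marker=marker))

# No c= was given: each collection has one fixed color from the property cycle
# and no data array, so there is nothing a colorbar could map back to z.
for k, coll in enumerate(collections):
    print(k, coll.get_facecolor(), coll.get_array())
```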

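The problem is that this method overwrites the color axis: since no c argument is passed, each scatter call is drawn in its own single color from the default color cycle, so the color now encodes the cluster instead of z, as shown in the cluster plot example image.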

How can I avoid this, so that the markers keep using the z values for the color axis? I want the plot to change only the marker style according to the cluster data, not the color. Thank you.


Solution

  • You need to map each cluster's z values onto the same color scale so that a single colorbar is valid for all three scatter plots. Build a Normalize object from the full range of z and pass it to every scatter call with norm=, together with that cluster's values as c=z[cluster == k].

    import matplotlib.pyplot as plt
    import numpy as np

    x = np.array([48.959, 49.758, 49.887, 50.593, 50.683])
    y = np.array([122.310, 121.29, 120.525, 120.252, 119.509])
    z = np.array([136.993, 133.128, 143.710, 129.088, 139.860])
    cluster = np.array([0, 1, 0, 2, 2])

    # One normalization built from the full z range, shared by all three calls
    norm = plt.Normalize(z.min(), z.max())
    fig, ax = plt.subplots()
    # Color still encodes z; only the marker depends on the cluster label
    a = ax.scatter(x[cluster == 0], y[cluster == 0], marker="*", c=z[cluster == 0], norm=norm)
    a = ax.scatter(x[cluster == 1], y[cluster == 1], marker="^", c=z[cluster == 1], norm=norm)
    a = ax.scatter(x[cluster == 2], y[cluster == 2], marker="s", c=z[cluster == 2], norm=norm)
    # Any of the three collections works here because they share the same norm and colormap
    fig.colorbar(a)
    plt.show()
    

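A variant of the answer above, sketched under the same assumptions (dummy data, hypothetical labels [0, 1, 0, 2, 2]): instead of creating a Normalize object, pass the same vmin/vmax and the same cmap to every scatter call, and loop over the labels found by np.unique with a hypothetical cluster-to-marker dictionary, so the code does not need one line per cluster. "viridis" is Matplotlib's default colormap; a custom colormap such as the question's cm1 could be passed instead. The output filename clusters.png is also just an illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; we save to a file instead of showing
import matplotlib.pyplot as plt
import numpy as np

x = np.array([48.959, 49.758, 49.887, 50.593, 50.683])
y = np.array([122.310, 121.29, 120.525, 120.252, 119.509])
z = np.array([136.993, 133.128, 143.710, 129.088, 139.860])
cluster = np.array([0, 1, 0, 2, 2])  # hypothetical k-means labels

# Hypothetical cluster-to-marker mapping; extend it if k-means finds more clusters
markers = {0: "*", 1: "^", 2: "s"}

fig, ax = plt.subplots()
vmin, vmax = z.min(), z.max()  # one color scale shared by every cluster
for k in np.unique(cluster):
    mask = cluster == k
    # Same cmap and same limits in every call, so equal z gives equal color
    sc = ax.scatter(x[mask], y[mask], c=z[mask], marker=markers[k],
                    cmap="viridis", vmin=vmin, vmax=vmax)

fig.colorbar(sc, ax=ax, label="z")
fig.savefig("clusters.png")  # hypothetical output path
```

Each call gets its own normalization with identical limits, so any one of the returned collections can feed fig.colorbar; the one from the last loop iteration is used here.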