When I plot the clustering result in 2D using this code:
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
# create 2d data
x, label = make_blobs(n_samples=3000, n_features=2, centers=4,
                      cluster_std=3, random_state=42)
# cluster data
kmeans = KMeans(init="k-means++", n_clusters=4, random_state=42)
kmeans.fit(x)
# plot clusters and centroids
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111)
ax.scatter(x[:, 0], x[:, 1], c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=40, alpha=0.5)
ax.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
           s=300, c='r', marker='*', label='Centroid')
ax.set_title("2D Kmeans clustering")
ax.set_xlabel("X")
ax.set_ylabel("Y")
plt.show()
I get the following output:
But when I try to plot the clustering result in 3D using this code:
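from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401, registers the '3d' projection on matplotlib < 3.2
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans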
# create 3d data
x, label = make_blobs(n_samples=3000, n_features=3, centers=4,
                      cluster_std=3, random_state=42)
# cluster data
kmeans = KMeans(init="k-means++", n_clusters=4, random_state=42)
kmeans.fit(x)
# plot clusters and centroids
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=40, alpha=0.5)
ax.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1],
           kmeans.cluster_centers_[:, 2], s=300, c='r',
           marker='*', label='Centroid')
ax.set_title("3D Kmeans clustering")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
plt.show()
I get the following output:
As you can see, the cluster centroids are not visible. I would like to see the centroid stars in the 3D plot; how can I achieve this?
Thanks in advance for any help!
The centroid stars are surrounded by the clouds of points and are therefore hard to see. You can make them more visible by making the points smaller (for example s=10 or less) and more transparent (for example alpha=0.1 or less), as in the following line:
ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=kmeans.labels_, cmap='viridis',
           edgecolor='k', s=10, alpha=0.1)
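Shrinking and fading the points helps, but mplot3d may still draw the whole point cloud over the stars, because by default it orders each scatter collection by its projected depth rather than by zorder. The sketch below combines the suggestion above with a second option, assuming matplotlib >= 3.5 (where Axes3D accepts `computed_zorder`): disable the computed ordering and give the centroids a higher `zorder`, so they are drawn last. The helper name `plot_clusters_3d` and its `point_size`/`point_alpha` parameters are hypothetical, introduced only for this example.

```python
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs


def plot_clusters_3d(x, kmeans, point_size=10, point_alpha=0.1):
    """Hypothetical helper: plot 3D points by cluster with centroids on top.

    Assumes matplotlib >= 3.5, where computed_zorder=False makes the 3D axes
    respect each artist's zorder instead of sorting collections by depth.
    """
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
    # Small, faint points so the cloud does not hide what lies inside it
    ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=kmeans.labels_, cmap="viridis",
               edgecolor="k", s=point_size, alpha=point_alpha, zorder=1)
    # Large, opaque stars with a higher zorder so they are drawn last;
    # depthshade=False keeps distant stars from being faded
    centers = kmeans.cluster_centers_
    ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], s=300, c="r",
               marker="*", label="Centroid", zorder=2, depthshade=False)
    ax.set_title("3D Kmeans clustering")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.legend()
    return fig, ax


if __name__ == "__main__":
    x, label = make_blobs(n_samples=3000, n_features=3, centers=4,
                          cluster_std=3, random_state=42)
    # Explicit n_init avoids a FutureWarning on some scikit-learn versions
    kmeans = KMeans(init="k-means++", n_clusters=4, n_init=10, random_state=42)
    kmeans.fit(x)
    plot_clusters_3d(x, kmeans)
    plt.show()
```

Either change can be used on its own: the smaller, more transparent points keep the clusters' shape readable, while the zorder approach guarantees the stars are painted over the cloud regardless of the viewing angle.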