Tags: python, matplotlib, scikit-learn, visualization, unsupervised-learning

Draw distance contours in a low-dimensional representation in Python


I have a set of n_samples data points. Each data point has n_features (on the order of hundreds or thousands of features). I use K-Means clustering with Euclidean distance to group the points into n_clusters. Then I use t-SNE (TSNE) to map my high-dimensional input data X (which is n_samples x n_features) to X_low_dim (which is n_samples x 2) so I can visualize the data in two dimensions. Is there an easy way to draw distance contours around the cluster centers in Python?
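The pipeline described above can be sketched with scikit-learn as follows. This is a minimal sketch, assuming `X` is a NumPy array of shape `(n_samples, n_features)`; the helper name `cluster_and_embed` is hypothetical. scikit-learn's `TSNE` only provides `fit_transform` (there is no `transform` for new points), so the K-Means centers cannot be projected afterwards. This sketch therefore uses the mean of each cluster's embedded points as that cluster's 2-D center.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE


def cluster_and_embed(X, n_clusters, random_state=0):
    """Hypothetical helper: K-Means in the original space, then a 2-D t-SNE embedding."""
    # Cluster in the full n_features space with Euclidean distance
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    labels = kmeans.fit_predict(X)

    # Embed all points in 2-D; TSNE cannot embed the K-Means centers separately
    X_low_dim = TSNE(n_components=2, random_state=random_state).fit_transform(X)

    # Approximate each projected center by the mean of its cluster's embedded points
    centers_low_dim = np.array(
        [X_low_dim[labels == k].mean(axis=0) for k in range(n_clusters)]
    )
    return labels, X_low_dim, centers_low_dim
```

An alternative is to append `kmeans.cluster_centers_` to `X` before calling `fit_transform`, so that the centers receive their own embedded positions. Either way, the per-cluster point arrays used in the answer below correspond to `X_low_dim[labels == k]`.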


Solution

  • I'm not sure whether I or others misunderstood the question, but if I read it correctly, you want contour plots centered on the 2-D projections of your cluster representatives. Note that the code below draws Gaussian KDE (density) contours of each cluster's projected points, which surround the representatives but are not literal distance contours.
    You can look here for a general approach to contour plots; adapting that code almost verbatim, you could do something like this:

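    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import cm
    import scipy.stats as st
    
    def contour_cloud(x, y, cmap, xlim=(-10, 10), ylim=(-10, 10)):
        # Fit a Gaussian KDE to the points and draw its filled contours on a 100x100 grid
        xmin, xmax = xlim
        ymin, ymax = ylim
    
        xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
        positions = np.vstack([xx.ravel(), yy.ravel()])  # (2, 10000) grid coordinates
        values = np.vstack([x, y])                       # (2, n_points) data coordinates
        kernel = st.gaussian_kde(values)
        f = np.reshape(kernel(positions), xx.shape)      # density evaluated on the grid
    
        plt.contourf(xx, yy, f, cmap=cmap, alpha=0.5)
    
    # Assuming 2 clusters, split the projected points into two subsets.
    # Demo data (replace with your own): 1000 points per cluster around (0, 0)
    # and (10, 10); the standard deviation of 1.0 is an assumption.
    rng = np.random.default_rng(0)
    representative_1 = np.array([0.0, 0.0])                        # Shape (2,)
    cluster_1 = rng.normal(representative_1, 1.0, size=(1000, 2))  # Shape (n_points_cl_1, 2)
    representative_2 = np.array([10.0, 10.0])                      # Shape (2,)
    cluster_2 = rng.normal(representative_2, 1.0, size=(1000, 2))  # Shape (n_points_cl_2, 2)
    
    # The grid limits must cover both clusters
    lims = (-5, 15)
    contour_cloud(cluster_1[:, 0], cluster_1[:, 1], cm.Blues, xlim=lims, ylim=lims)
    contour_cloud(cluster_2[:, 0], cluster_2[:, 1], cm.Reds, xlim=lims, ylim=lims)
    
    # Draw the representatives last so the filled contours do not cover them
    plt.scatter(x=representative_1[0], y=representative_1[1], c='b')
    plt.scatter(x=representative_2[0], y=representative_2[1], c='r')
    
    plt.show()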
    

    Set xmin, xmax, ymin, and ymax according to the range of your data.

    This will output a plot along these lines (original image not available).

    Tweak the parameters to fit your needs; I threw this together in five minutes, so it is not very polished. In that plot, I sampled 1000 points from two different normal distributions and used their means, (0, 0) and (10, 10), as representatives.
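The code above draws density contours. If you want literal distance contours, one option is to contour the Euclidean distance, measured in the 2-D embedding, from each grid point to the nearest projected cluster center. This is a sketch under assumptions: `X_low_dim`, `labels` and `centers` (shape `(n_clusters, 2)`) are assumed to come from your own pipeline, and the function names, the padding and the default `levels` are hypothetical choices. Keep in mind that t-SNE does not preserve distances from the original space, so these contours describe the embedding, not the high-dimensional distances K-Means used.

```python
import numpy as np
import matplotlib.pyplot as plt


def distance_to_nearest_center(xx, yy, centers):
    """Euclidean distance (in the 2-D embedding) from each grid point to its nearest center."""
    grid = np.stack([xx, yy], axis=-1)                  # (ny, nx, 2)
    diffs = grid[..., np.newaxis, :] - centers          # (ny, nx, n_clusters, 2)
    return np.linalg.norm(diffs, axis=-1).min(axis=-1)  # (ny, nx)


def plot_distance_contours(X_low_dim, labels, centers, levels=(1, 2, 4, 8),
                           pad=2.0, n_grid=200):
    """Hypothetical helper: scatter the embedding and draw distance contours around centers."""
    # Derive the grid limits from the data instead of hard-coding them
    xmin, ymin = X_low_dim.min(axis=0) - pad
    xmax, ymax = X_low_dim.max(axis=0) + pad
    xx, yy = np.meshgrid(np.linspace(xmin, xmax, n_grid),
                         np.linspace(ymin, ymax, n_grid))
    d = distance_to_nearest_center(xx, yy, centers)

    plt.scatter(X_low_dim[:, 0], X_low_dim[:, 1], c=labels, s=5, cmap="tab10")
    cs = plt.contour(xx, yy, d, levels=levels, colors="k", linewidths=0.8)
    plt.clabel(cs, fmt="%g")  # label each contour with its distance value
    plt.scatter(centers[:, 0], centers[:, 1], c="r", marker="x", s=100)
    plt.gca().set_aspect("equal")  # keep circles circular
```

Call `plot_distance_contours(X_low_dim, labels, centers)` followed by `plt.show()`. Using the minimum over centers gives one set of contours per cluster that merge where clusters meet; to draw separate rings per cluster, call `plt.contour` once per center on that center's own distance array instead.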