
How does indexing with comma work on Python's plt?


I am following a machine learning course with only basic knowledge of Python. While working through a Towards Data Science example on K-means clustering, I came across a way of indexing that I did not get to ask my professor about during the lecture. In the part where the clusters are plotted along with the centroids, the author uses indexing like this:

plt.scatter(
    X[y_km == 2, 0], X[y_km == 2, 1],
    s=50, c='lightblue',
    marker='v', edgecolor='black',
    label='cluster 3'
)

Does anybody know how this works?

I've tried running the indexing expression on its own, outside of plt.scatter, but that didn't tell me more than I already knew.


Solution

  • X is a NumPy array with 2 columns. You can think of them as x and y coordinates.
    Printing the first 10 rows gives:

    print(X[0:10])
    
    [[ 2.60509732  1.22529553]
     [ 0.5323772   3.31338909]
     [ 0.802314    4.38196181]
     [ 0.5285368   4.49723858]
     [ 2.61858548  0.35769791]
     [ 1.59141542  4.90497725]
     [ 1.74265969  5.03846671]
     [ 2.37533328  0.08918564]
     [-2.12133364  2.66447408]
     [ 1.72039618  5.25173192]]
    

    y_km holds the cluster label assigned to each row of X.
    In the example, each label is 0, 1, or 2:

    print(y_km[0:10])
    
    [1 0 0 0 1 0 0 1 2 0]
    

    The comparison y_km == 1 is evaluated element by element and produces a Boolean NumPy array (not a Python list), with one True or False per row:

    print((y_km==1)[0:10])
    
    [ True False False False True False False True False False]
    

    So when you call

    X[y_km == 1 , 1]

    the part before the comma selects rows and the part after the comma selects a column. The Boolean array y_km == 1 keeps only the rows of X where it is True, that is, the points assigned to cluster 1, and the 1 after the comma takes the value in column 1 (the y coordinate) from each of those rows.



    And

    X[y_km == 2, 0]

    This returns column 0 (the x coordinates) of the rows of X whose label in y_km is 2.

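    So the expression before the comma picks the cluster whose points you want, and the number after the comma picks which column of X to retrieve for those points. That is why the plt.scatter call passes the same row selection twice: once with column 0 for the x values and once with column 1 for the y values.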
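
To try the indexing outside of plt.scatter, here is a minimal sketch. The six points and labels below are made-up stand-ins for the tutorial's X and y_km, not its real data, and the names mask, xs, ys and same_xs are only for illustration.

```python
import numpy as np

# Hypothetical stand-ins for the tutorial's data:
# six 2-D points and the cluster label assigned to each one.
X = np.array([
    [2.6, 1.2],
    [0.5, 3.3],
    [0.8, 4.4],
    [2.6, 0.4],
    [-2.1, 2.7],
    [1.7, 5.3],
])
y_km = np.array([1, 0, 0, 1, 2, 0])

# Step 1: the comparison yields a Boolean array, one entry per row of X.
mask = (y_km == 1)

# Step 2: the mask selects rows; the integer after the comma selects a column.
xs = X[mask, 0]  # x coordinates of the points in cluster 1
ys = X[mask, 1]  # y coordinates of the points in cluster 1

# Equivalent two-step form: filter the rows first, then take the column.
same_xs = X[mask][:, 0]

print(mask)
print(xs)
print(ys)
```

Writing the mask on its own line first makes it easier to see that X[y_km == 1, 0] is ordinary two-axis NumPy indexing (rows, then column) and has nothing to do with matplotlib itself.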