python matplotlib scatter-plot

Matplotlib scatterplot, color as function of element in array


I'm trying to plot data with different colors depending on their classification. The data is in an n×3 array: the first column is the x position, the second column is the y position, and the third column is an integer giving the category. I can do this by running a for loop over the entire array and plotting each point individually, but I have found that doing so is very slow for large arrays.

So, this works.

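import numpy as np
import matplotlib.pyplot as plt

data = np.loadtxt('data.csv', delimiter=",")
colors = ['r', 'g', 'b']

fig = plt.figure()
for i in data:
    # One scatter call per row: correct, but slow for large arrays
    plt.scatter(i[0], i[1], color=colors[int(i[2]) % 3])
plt.show()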

This does not work, but I want it to, as something along this line would avoid using a for loop.

data = np.loadtxt('data.csv', delimiter = ",")
colors = ['r', 'g', 'b']

fig = plt.figure()
# Raises TypeError when data has more than one row: int() cannot convert a whole array
plt.scatter(data[:,0], data[:,1], color = colors[int(data[:,2]) % 3])
plt.show()

Solution

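  • Your code doesn't work because x and y are whole arrays taken from the data, while the color is not: int() cannot convert an entire array column, and a Python list cannot be indexed with an array. So you have to build the colors as an array (or list) with one entry per point. Take a look at the matplotlib gallery: https://matplotlib.org/stable/gallery/shapes_and_collections/scatter.html They have this example there: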

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Fixing random state for reproducibility
    np.random.seed(19680801)
    
    N = 50
    x = np.random.rand(N)
    y = np.random.rand(N)
    colors = np.random.rand(N)
    area = (30 * np.random.rand(N))**2  # 0 to 15 point radii
    
    plt.scatter(x, y, s=area, c=colors, alpha=0.5)
    plt.show()
    

    Here, x and y play the same role as your first two columns. You probably won't need s. The key point is that c is an array with one entry per point. You can do something like this:

    colors = ['r', 'g', 'b']
    colors_list = [colors[int(i) % 3] for i in data[:,2]]
    plt.scatter(data[:,0], data[:,1], c = colors_list)
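
The list comprehension above still loops in Python to build the colors. A possible fully vectorized variant, sketched here under assumptions, turns the color list into a NumPy array and indexes it with the integer category column. The random `data` array is a stand-in for data.csv, and the names `palette` and `point_colors` are illustrative, not from the original post.

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in for data.csv (assumption): columns are x, y, integer category.
rng = np.random.default_rng(0)
data = np.column_stack([rng.random(20), rng.random(20), rng.integers(0, 6, 20)])

palette = np.array(['r', 'g', 'b'])
# Cast the category column to integers, wrap it into 0..2, and use the result
# as an index array: this yields one color string per row without a Python loop.
point_colors = palette[data[:, 2].astype(int) % 3]

fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], c=point_colors)
```

Call plt.show() afterwards to display the figure. The indexing step works because a NumPy array, unlike a plain list, accepts an integer array as an index.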
    

    Note that I don't have your data, so I could not test this; you may need to tweak the code.
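
Another option, again only a sketch under assumptions, is to pass the integer categories directly as `c` and let a `ListedColormap` translate them into colors, so no per-point color strings are built at all. The random `data` array stands in for data.csv, and `categories`, `cmap`, and `points` are illustrative names.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Synthetic stand-in for data.csv (assumption): columns are x, y, integer category.
rng = np.random.default_rng(1)
data = np.column_stack([rng.random(20), rng.random(20), rng.integers(0, 6, 20)])

# Wrap the categories into 0..2 so they line up with the three colors.
categories = data[:, 2].astype(int) % 3
cmap = ListedColormap(['r', 'g', 'b'])

fig, ax = plt.subplots()
# vmin/vmax pin the mapping to 0 -> 'r', 1 -> 'g', 2 -> 'b', even if one of
# the categories happens to be absent from the data.
points = ax.scatter(data[:, 0], data[:, 1], c=categories, cmap=cmap, vmin=0, vmax=2)
```

Call plt.show() afterwards to display the figure. This form keeps the numeric categories attached to the plot, which can be convenient if you later want a colorbar.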