
Scatter plot for points in an array above a given value


import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()

eps = 0.8

X = np.linspace(-1, 1, 100)
Y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(X, Y)
Z = np.exp(-X**2-Y**2)

data_zero_x = np.array([])
data_zero_y = np.array([])

for i in range(len(X)):
    for j in range(len(Y)):
        if Z[i][j] > eps:
            data_zero_x = np.append(data_zero_x, X[i])
            data_zero_y = np.append(data_zero_y, Y[j])

plt.scatter(data_zero_x, data_zero_y)
plt.show()

Hey there! I would expect this code to produce a circular cluster of points around the origin, since that is where the function Z is above eps=0.8. Instead, I get a rectangular picture. Any ideas what I'm doing wrong here? Also, if there is a better way to code something like this, I am all ears.


Solution

  • Try this:

    import matplotlib.pyplot as plt
    import numpy as np
    
    eps = 0.8
    
    X = np.linspace(-1, 1, 100)
    Y = np.linspace(-1, 1, 100)
    X, Y = np.meshgrid(X, Y)
    Z = np.exp(-X**2-Y**2)
    
    mask = Z > eps
    plt.scatter(X[mask], Y[mask])
    plt.show()
    

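    In your loop, X[i] and Y[j] are not single numbers. After np.meshgrid, X and Y are 2D arrays of shape (100, 100), so X[i] is a whole row of 100 x values spanning [-1, 1], and each np.append adds 100 points at once. That is why you get a rectangle. Since you are working with NumPy arrays, there is also no need to loop over the whole array and check the condition (> 0.8) element by element. NumPy overloads the comparison operators so that comparing an array with a scalar is applied elementwise and returns a boolean array of the same shape, with True wherever the condition holds. You can then use this boolean array as a mask to select the matching elements from other arrays of the same shape, as in the code above.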

    Also, you don't need the line fig = plt.figure() when you are using plt.scatter, because pyplot creates a figure automatically. You only need it if you want to work with the object-oriented approach, where you create the figure and the axes explicitly and call the plotting methods of the axes object.
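For comparison, here is a minimal sketch that puts the points above side by side: it shows that X[i] is a full row, fixes the original loop by indexing single elements with [i, j], checks that the loop and the boolean mask select the same points, and plots the result with the object-oriented API. The names loop_x, mask_x and radius, the marker size s=5 and the call to ax.set_aspect("equal") are illustrative choices, not part of the answer above. The radius check relies on exp(-(x² + y²)) > eps being equivalent to x² + y² < -ln(eps), which is why the selected points form a disk.

```python
import numpy as np
import matplotlib.pyplot as plt

eps = 0.8

x = np.linspace(-1, 1, 100)
y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(x, y)  # both X and Y have shape (100, 100)
Z = np.exp(-X**2 - Y**2)

# Why the original loop draws a rectangle: X[i] is a whole row of
# the 2D grid, not a single coordinate.
print(X[0].shape)  # (100,)

# Minimal fix of the loop: index single elements with [i, j].
# Plain Python lists avoid np.append copying the whole array on every call.
loop_x, loop_y = [], []
for i in range(Z.shape[0]):
    for j in range(Z.shape[1]):
        if Z[i, j] > eps:
            loop_x.append(X[i, j])
            loop_y.append(Y[i, j])

# Vectorised version: the comparison yields a boolean array of shape (100, 100).
mask = Z > eps
mask_x, mask_y = X[mask], Y[mask]

# Boolean indexing walks the array in row-major order, like the nested loop,
# so both approaches select the same points in the same order.
assert np.array_equal(loop_x, mask_x) and np.array_equal(loop_y, mask_y)

# exp(-(x^2 + y^2)) > eps  <=>  x^2 + y^2 < -ln(eps): a disk around the origin.
radius = np.sqrt(-np.log(eps))  # roughly 0.47 for eps = 0.8
assert np.all(mask_x**2 + mask_y**2 < radius**2)

# Object-oriented API: create the figure and axes explicitly,
# then call the plotting methods on the axes object.
fig, ax = plt.subplots()
ax.scatter(mask_x, mask_y, s=5)
ax.set_aspect("equal")  # equal axis scaling so the disk actually looks round
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
plt.show()
```

The loop version is only there to show what went wrong; in practice the two-line mask version is both shorter and much faster, since the comparison and the selection run inside NumPy rather than in a Python loop.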