import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure()
eps = 0.8
X = np.linspace(-1, 1, 100)
Y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(X, Y)
Z = np.exp(-X**2-Y**2)
data_zero_x = np.array([])
data_zero_y = np.array([])
for i in range(len(X)):
    for j in range(len(Y)):
        if Z[i][j] > eps:
            data_zero_x = np.append(data_zero_x, X[i])
            data_zero_y = np.append(data_zero_y, Y[j])
plt.scatter(data_zero_x, data_zero_y)
plt.show()
Hey there! I would expect this code to produce a circular patch of points around the origin, since that is where the function Z is above eps=0.8. Instead, I get a rectangle. Any ideas what I'm doing wrong here? Also, if there is a better way to code something like this, I am all ears.
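Your loop appends whole rows instead of single points. After np.meshgrid, X[i] is the entire i-th row of X (all 100 x-values) and Y[j] is the j-th row of Y (100 copies of a single y-value), so every grid point that passes the test adds a full-width horizontal line of points, and those lines stack up into a rectangle. The point that actually belongs to Z[i][j] is (X[i][j], Y[i][j]). A simpler way is to skip the loop entirely. Try this: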
import matplotlib.pyplot as plt
import numpy as np
eps = 0.8
X = np.linspace(-1, 1, 100)
Y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(X, Y)
Z = np.exp(-X**2-Y**2)
mask = Z > eps
plt.scatter(X[mask], Y[mask])
plt.show()
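If you would rather keep the explicit loop, here is a minimal sketch of a corrected version. It takes the single grid point (X[i, j], Y[i, j]) that Z[i, j] was computed from, and it collects coordinates in plain Python lists because np.append copies the whole array on every call. The names xs and ys are just illustrative.

```python
import numpy as np

eps = 0.8
x = np.linspace(-1, 1, 100)
y = np.linspace(-1, 1, 100)
# Default indexing="xy": X[i, j] == x[j] and Y[i, j] == y[i]
X, Y = np.meshgrid(x, y)
Z = np.exp(-X**2 - Y**2)

# Appending to a list is cheap; np.append would copy the array each time.
xs, ys = [], []
n_rows, n_cols = Z.shape
for i in range(n_rows):
    for j in range(n_cols):
        if Z[i, j] > eps:
            # Use the one grid point Z[i, j] was evaluated at,
            # not the whole rows X[i] and Y[j].
            xs.append(X[i, j])
            ys.append(Y[i, j])

xs = np.array(xs)
ys = np.array(ys)
# plt.scatter(xs, ys) now draws a disk instead of a rectangle.
```

It selects the same points, in the same order, as X[Z > eps] and Y[Z > eps], only more slowly.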
Since you are working with NumPy arrays, there is no need to loop over the complete array and check your condition (> 0.8) for each element yourself. NumPy arrays overload the comparison operators, so comparing a whole array with a number is done element by element and returns a new array of the same shape containing True or False for each element of the original. You can then use this boolean array as a mask to select elements from other arrays of the same shape, as in the code above.
Also, you don't need the line fig = plt.figure() when you are working with plt.scatter, because plt.scatter creates a figure and axes automatically if none exist yet. You only need it if you want to work with the object-oriented approach, where you create the figure and the axes explicitly and call the plotting methods of the axes objects.
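For comparison, here is a minimal sketch of that object-oriented style, using plt.subplots() to create the figure and axes in one call and then calling scatter on the axes. The ax.set_aspect("equal") line is an addition not in the answer above; it gives both axes the same scale so the disk is not drawn as an ellipse.

```python
import matplotlib.pyplot as plt
import numpy as np

eps = 0.8
x = np.linspace(-1, 1, 100)
y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(x, y)
Z = np.exp(-X**2 - Y**2)
mask = Z > eps

# Object-oriented style: create the figure and axes explicitly...
fig, ax = plt.subplots()
# ...and call the plotting method on the axes object.
ax.scatter(X[mask], Y[mask])
# Addition beyond the original answer: one unit in x has the same
# on-screen length as one unit in y, so the patch looks circular.
ax.set_aspect("equal")
plt.show()
```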