I have a long list whose elements are ints. I want to find the indices of the elements that equal a certain number, and I use np.where to do this.
My original code is the following:
import numpy as np
x = [1, 1, 2, 3]
y = np.array(x, dtype=np.float32)
idx = np.where(y == 1)[0].tolist()
print(idx)  # [0, 1]
Some time later, I realized that I should not use dtype=np.float32, because it changes the data type of y to float. The correct code should be:
import numpy as np
x = [1, 1, 2, 3]
y = np.array(x)
idx = np.where(y == 1)[0].tolist()
print(idx)  # also [0, 1]
Surprisingly, these two snippets produce exactly the same result.
Why? How is the equality test in numpy.where handled when the data types of the array and the target are not compatible (e.g. int vs float)?
NumPy's where (source code here) is not concerned with comparing data types: its first argument is an array of bool type. When you write y == 1, that is an array comparison which returns a Boolean array, and that Boolean array is what gets passed to where.
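To make the order of operations visible, here is a minimal sketch that builds the Boolean mask for each array explicitly and hands only the mask to np.where. It assumes the same x = [1, 1, 2, 3] as the question; the names y_float, y_int, mask_float and mask_int are just illustrative.

```python
import numpy as np

x = [1, 1, 2, 3]
y_float = np.array(x, dtype=np.float32)  # float32 array, as in the first snippet
y_int = np.array(x)                      # integer array, as in the second snippet

# The comparison happens here, before np.where is ever called.
mask_float = y_float == 1
mask_int = y_int == 1

# Both masks have dtype bool; the original dtype of y is gone at this point.
print(mask_float.dtype, mask_float)  # bool [ True  True False False]
print(mask_int.dtype, mask_int)      # bool [ True  True False False]

# With a single argument, np.where only sees the Boolean mask,
# so both arrays yield the same indices.
idx_float = np.where(mask_float)[0].tolist()
idx_int = np.where(mask_int)[0].tolist()
print(idx_float, idx_int)  # [0, 1] [0, 1]
```

Since np.where never sees y itself, whether y is int or float32 only matters inside the == step.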
The relevant function is the ufunc numpy.equal, which you invoke implicitly by writing y == 1. Its documentation says:
"What is compared are values, not types."
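A small sketch of that behaviour, calling np.equal directly. The arrays below are made up for illustration; the point is that == on arrays gives the same result as the explicit ufunc call, and that int and float arrays holding the same numbers compare equal element-wise.

```python
import numpy as np

y_float = np.array([1, 1, 2, 3], dtype=np.float32)

# y == 1 is shorthand for the ufunc call np.equal(y, 1).
same = np.array_equal(y_float == 1, np.equal(y_float, 1))

# Values are compared, not types: an int array and a float array
# holding the same numbers compare equal element-wise.
ints = np.array([1, 2, 3])
floats = np.array([1.0, 2.0, 3.0])
elementwise = np.equal(ints, floats)

# A plain int against a length-one float array also compares by value.
mixed = np.equal(1, np.ones(1))

print(same)         # True
print(elementwise)  # [ True  True  True]
print(mixed)        # [ True]
```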
For example,
x, y, z = np.float64(0.25), np.float32(0.25), 0.25
These all have different types (numpy.float64, numpy.float32, and float), but x == y, y == z, and x == z are all True. What matters here is that 0.25 (1/4) is exactly representable in binary.
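A quick sketch that checks this example. float.hex is used only to show that widening the float32 value to a 64-bit Python float gives bit-for-bit the same number as the literal 0.25, because 1/4 is an exact power of two.

```python
import numpy as np

x, y, z = np.float64(0.25), np.float32(0.25), 0.25

# Three different types...
types = (type(x), type(y), type(z))
print(types)

# ...but 0.25 == 1/4 is exactly representable, so the float32 value,
# widened to 64 bits, is identical to the Python float.
print(float(y).hex(), z.hex())  # 0x1.0000000000000p-2 0x1.0000000000000p-2

print(x == y, y == z, x == z)  # True True True
```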
With
x, y, z = np.float64(0.2), np.float32(0.2), 0.2
we see that x == y is False and y == z is False, but x == z is True, because Python floats are 64-bit, just like np.float64. Since 1/5 is not exactly representable in binary, 32 bits and 64 bits give two different approximations of 1/5. That is why equality fails: not because of the types, but because np.float64(0.2) and np.float32(0.2) are actually different values (they differ by about 3e-9).
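To see the two approximations side by side, one can widen the float32 value to a Python float, which represents it exactly, and subtract. This is a small illustrative sketch; it only compares values that are both NumPy scalars or both 64-bit, so the results do not depend on NumPy's rules for mixing NumPy scalars with Python floats.

```python
import numpy as np

x, y, z = np.float64(0.2), np.float32(0.2), 0.2

# Widening float32 to a 64-bit Python float is exact, so this shows
# the value float32 actually stores for 0.2.
y_exact = float(y)
print(repr(y_exact))  # 0.20000000298023224
print(repr(z))        # 0.2

# The float32 and float64 approximations of 1/5 differ by roughly 3e-9.
diff = y_exact - z
print(diff)

print(x == y)  # False: float32 is promoted to float64, exposing the difference
print(x == z)  # True: both are the same 64-bit approximation
```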