OS: Windows 10 (x64), Build 1909
Python Version: 3.8.10
Numpy Version: 1.21.2
Given two 2D (N, 3) NumPy arrays of (x, y, z) floating-point data points, what is the Pythonic (vectorized) way to find the indices in one array where the points are equal to points in the other array?
(NOTE: My question differs from the ones below in that I need this to work with real-world data sets, where the two data sets may differ by floating-point error. Please read on below for details.)
Very similar questions have been asked many times:
SO Post 1 provides a working list comprehension solution, but I am looking for a solution that will scale well to large data sets (i.e. millions of points):
Code 1:
import numpy as np

if __name__ == "__main__":
    big_array = np.array(
        [
            [1.0, 2.0, 1.2],
            [5.0, 3.0, 0.12],
            [-1.0, 14.0, 0.0],
            [-9.0, 0.0, 13.0],
        ]
    )
    small_array = np.array(
        [
            [5.0, 3.0, 0.12],
            [-9.0, 0.0, 13.0],
        ]
    )
    # Compare every big row against every small row (O(N * M) Python loop)
    inds = [
        ndx
        for ndx, barr in enumerate(big_array)
        for sarr in small_array
        if all(sarr == barr)
    ]
    print(inds)
Output 1:
[1, 3]
Attempting the solution from SO Post 3 (similar to SO Post 2) with floats does not work (and I suspect something using np.isclose will be needed):
Code 3:
import numpy as np

if __name__ == "__main__":
    big_array = np.array(
        [
            [1.0, 2.0, 1.2],
            [5.0, 3.0, 0.12],
            [-1.0, 14.0, 0.0],
            [-9.0, 0.0, 13.0],
        ],
        dtype=float,
    )
    small_array = np.array(
        [
            [5.0, 3.0, 0.12],
            [-9.0, 0.0, 13.0],
        ],
        dtype=float,
    )
    inds = np.nonzero(
        np.in1d(big_array.view("f,f").reshape(-1), small_array.view("f,f").reshape(-1))
    )[0]
    print(inds)
Output 3:
[ 3 4 5 8 9 10 11]
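The output of Code 3 is easier to interpret once you see what the view does. The sketch below (same data, nothing beyond NumPy assumed) shows that "f,f" does not turn each row into one record, so np.in1d compares individual coordinates and returns flat element indices. Index 8, for example, is the 0.0 in row 2, which happens to match the 0.0 in small_array.

```python
import numpy as np

# Same float64 data as Code 3, shape (4, 3).
big_array = np.array(
    [
        [1.0, 2.0, 1.2],
        [5.0, 3.0, 0.12],
        [-1.0, 14.0, 0.0],
        [-9.0, 0.0, 13.0],
    ],
    dtype=float,
)

# "f,f" is a record of two float32 fields: 8 bytes, the same size as a single
# float64. The view therefore reinterprets each element on its own and does
# not group the three coordinates of a row into one record.
records = big_array.view("f,f")
print(records.shape)              # (4, 3)
print(records.reshape(-1).shape)  # (12,)

# The indices printed by Code 3 are positions in that flattened 12-element array.
rows, cols = np.unravel_index([3, 4, 5, 8, 9, 10, 11], big_array.shape)
print(rows)  # [1 1 1 2 3 3 3]
print(cols)  # [0 1 2 2 0 1 2]
```

A row-wise version of this trick would need a record type spanning a whole row (24 bytes for three float64 values), but it would still compare bit patterns exactly, so it would not help with floating-point error.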
I have tried numpy.isin with np.all and np.argwhere:
inds = np.argwhere(np.all(np.isin(big_array, small_array), axis=1)).reshape(-1)
which works on this example (and is, I would argue, much more readable and understandable, i.e. Pythonic), but it will not work for real-world data sets containing floating-point errors:
import numpy as np

if __name__ == "__main__":
    big_array = np.array(
        [
            [1.0, 2.0, 1.2],
            [5.0, 3.0, 0.12],
            [-1.0, 14.0, 0.0],
            [-9.0, 0.0, 13.0],
        ],
        dtype=float,
    )
    small_array = np.array(
        [
            [5.0, 3.0, 0.12],
            [-9.0, 0.0, 13.0],
        ],
        dtype=float,
    )
    # Same points as small_array, shifted by 1e-9 in every coordinate
    small_array_fpe = np.array(
        [
            [5.0 + 1e-9, 3.0 + 1e-9, 0.12 + 1e-9],
            [-9.0 + 1e-9, 0.0 + 1e-9, 13.0 + 1e-9],
        ],
        dtype=float,
    )
    inds_no_fpe = np.argwhere(np.all(np.isin(big_array, small_array), axis=1)).reshape(-1)
    inds_with_fpe = np.argwhere(
        np.all(np.isin(big_array, small_array_fpe), axis=1)
    ).reshape(-1)
    print(f"No Floating Point Error: {inds_no_fpe}")
    print(f"With Floating Point Error: {inds_with_fpe}")
    print(f"Are 5.0 and 5.0+1e-9 close?: {np.isclose(5.0, 5.0 + 1e-9)}")
Output:
No Floating Point Error: [1 3]
With Floating Point Error: []
Are 5.0 and 5.0+1e-9 close?: True
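Note that np.isin tests each coordinate independently against every value in small_array, so the isin/all combination only appears to match whole rows. The sketch below adds a hypothetical fifth row, [5.0, 0.0, 13.0], whose values all occur somewhere in small_array even though the row itself does not. The isin check reports it, while an explicit row-by-row comparison via broadcasting does not.

```python
import numpy as np

big_array = np.array(
    [
        [1.0, 2.0, 1.2],
        [5.0, 3.0, 0.12],
        [-1.0, 14.0, 0.0],
        [-9.0, 0.0, 13.0],
        [5.0, 0.0, 13.0],  # hypothetical row: each value is in small_array, the row is not
    ]
)
small_array = np.array(
    [
        [5.0, 3.0, 0.12],
        [-9.0, 0.0, 13.0],
    ]
)

# Element-wise membership: every value is checked on its own, so row 4 passes too.
isin_inds = np.argwhere(np.all(np.isin(big_array, small_array), axis=1)).reshape(-1)

# Row-wise exact match: (5, 1, 3) == (1, 2, 3) broadcasts to (5, 2, 3);
# a big row matches if all 3 coordinates equal those of some small row.
row_inds = np.nonzero(
    (big_array[:, None, :] == small_array[None, :, :]).all(axis=-1).any(axis=-1)
)[0]

print(isin_inds)  # [1 3 4]
print(row_inds)   # [1 3]
```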
How can I make my above solution work on data sets with floating-point error by incorporating np.isclose? Alternative solutions are welcome.
NOTE: Since small_array has fewer rows than big_array (it is a subset of it), using np.isclose directly doesn't work because the shapes won't broadcast:
np.isclose(big_array, small_array_fpe)
yields
ValueError: operands could not be broadcast together with shapes (4,3) (2,3)
Currently, the only working solution I have is:
inds_with_fpe = [
    ndx
    for ndx, barr in enumerate(big_array)
    for sarr in small_array_fpe
    if np.all(np.isclose(sarr, barr))
]
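The list comprehension above can be vectorized by inserting a length-1 axis so that the shapes do broadcast: (4, 1, 3) against (2, 3) gives (4, 2, 3). This is the same pattern as the A. Hendry line in the timings below. A minimal sketch with the small_array_fpe data, using np.isclose's default rtol=1e-05 and atol=1e-08:

```python
import numpy as np

big_array = np.array(
    [
        [1.0, 2.0, 1.2],
        [5.0, 3.0, 0.12],
        [-1.0, 14.0, 0.0],
        [-9.0, 0.0, 13.0],
    ]
)
small_array_fpe = np.array(
    [
        [5.0 + 1e-9, 3.0 + 1e-9, 0.12 + 1e-9],
        [-9.0 + 1e-9, 0.0 + 1e-9, 13.0 + 1e-9],
    ]
)

# (4, 1, 3) vs (2, 3) broadcasts to (4, 2, 3): every big row against every small row.
close = np.isclose(big_array[:, None, :], small_array_fpe)
# A big row matches if all 3 coordinates are close to those of some small row.
inds_with_fpe = np.nonzero(close.all(axis=-1).any(axis=-1))[0]
print(inds_with_fpe)  # [1 3]
```

The intermediate array has len(big_array) * len(small_array) * 3 elements, so memory grows with the product of the two sizes: with 100,000 and 1,000 points the boolean result alone holds 3 * 10^8 entries.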
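As @Michael Anderson already mentioned, this can be implemented using a k-d tree. In comparison to your answer, this solution uses an absolute tolerance (a Euclidean distance between points) rather than np.isclose's combined relative and absolute tolerance. Whether this is acceptable depends on the problem.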
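For reference, the two acceptance tests being compared, written for two points a and b as in np.isclose(a, b) (NumPy's default rtol and atol shown; r is the tolerance passed to query_ball_tree with its default p=2):

```latex
\begin{aligned}
\texttt{np.isclose}:&\quad |a_i - b_i| \le \mathrm{atol} + \mathrm{rtol}\,|b_i|
  \quad \text{for each } i \in \{x, y, z\},
  \qquad \mathrm{rtol} = 10^{-5},\ \mathrm{atol} = 10^{-8} \\
\text{k-d tree ball query } (p = 2):&\quad
  \lVert \mathbf{a} - \mathbf{b} \rVert_2 = \sqrt{\textstyle\sum_{i} (a_i - b_i)^2} \le r
\end{aligned}
```

The first test scales with the magnitude of each coordinate of b; the second is a fixed radius, so an offset of δ in each of the three coordinates corresponds to a distance of √3·δ.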
Example
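import numpy as np
from scipy import spatial

def find_nearest(big_array, small_array, tolerance):
    # Build a k-d tree for each point set
    tree_big = spatial.cKDTree(big_array)
    tree_small = spatial.cKDTree(small_array)
    # For each point in small_array, list the indices of all big_array points
    # within Euclidean distance `tolerance`
    return tree_small.query_ball_tree(tree_big, r=tolerance)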
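Since query_ball_tree returns one Python list per point of small_array rather than an index array, here is a short usage sketch that flattens the result into big_array indices. It reuses the question's small_array_fpe data; the tolerance of 1e-6 is an arbitrary choice for illustration.

```python
import numpy as np
from scipy import spatial


def find_nearest(big_array, small_array, tolerance):
    # Same function as above: for each small_array point, the list of
    # big_array indices within Euclidean distance `tolerance`.
    tree_big = spatial.cKDTree(big_array)
    tree_small = spatial.cKDTree(small_array)
    return tree_small.query_ball_tree(tree_big, r=tolerance)


big_array = np.array(
    [
        [1.0, 2.0, 1.2],
        [5.0, 3.0, 0.12],
        [-1.0, 14.0, 0.0],
        [-9.0, 0.0, 13.0],
    ]
)
small_array_fpe = np.array(
    [
        [5.0 + 1e-9, 3.0 + 1e-9, 0.12 + 1e-9],
        [-9.0 + 1e-9, 0.0 + 1e-9, 13.0 + 1e-9],
    ]
)

# One list per small_array_fpe row; flatten and de-duplicate to get big_array indices.
matches = find_nearest(big_array, small_array_fpe, 1e-6)
inds = np.unique(np.concatenate([np.asarray(m, dtype=int) for m in matches]))
print(inds)  # [1 3]

# With r=1e-9 nothing is found: each row is offset by about 1e-9 in all three
# coordinates, i.e. about sqrt(3) * 1e-9 ~ 1.7e-9 in Euclidean distance.
strict = find_nearest(big_array, small_array_fpe, 1e-9)
print([len(m) for m in strict])  # [0, 0]
```

If a per-coordinate absolute tolerance is wanted instead of a Euclidean one, query_ball_tree also accepts p=np.inf (Chebyshev distance).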
Timings
big_array = np.random.rand(100_000, 3)
small_array = np.random.rand(1_000, 3)
big_array[1000:2000] = small_array
%timeit find_nearest(big_array, small_array, 1e-9)  # find all pairs within a distance of 1e-9
# 55.7 ms ± 830 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# A. Hendry (np.isclose with broadcasting)
%timeit np.argwhere(np.isclose(small_array, big_array[:, None, :]).all(-1).any(-1)).reshape(-1)
# 3.24 s ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
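If a NumPy index array is preferred over lists of lists, cKDTree.query with distance_upper_bound is an alternative worth trying. In the sketch below, nearest_match_indices is a hypothetical helper name, and the data are synthetic and smaller than in the timings above. For each small_array row it returns the index of the single nearest big_array row, or -1 when none lies within the tolerance.

```python
import numpy as np
from scipy import spatial


def nearest_match_indices(big_array, small_array, tolerance):
    """Hypothetical helper: for each row of small_array, return the index of
    the nearest row of big_array, or -1 if no row lies within `tolerance`
    (Euclidean distance)."""
    tree_big = spatial.cKDTree(big_array)
    # With k=1 and 2-D input, dist and idx are 1-D arrays of length len(small_array).
    dist, idx = tree_big.query(small_array, k=1, distance_upper_bound=tolerance)
    # Points with no neighbour inside the bound come back with dist == inf
    # (and idx == tree_big.n), so map those to -1.
    return np.where(np.isfinite(dist), idx, -1)


rng = np.random.default_rng(0)
big_array = rng.random((10_000, 3))
small_array = rng.random((100, 3))
big_array[100:200] = small_array
small_array_fpe = small_array + 1e-9  # simulated floating-point error

inds = nearest_match_indices(big_array, small_array_fpe, 1e-6)
miss = nearest_match_indices(big_array, np.array([[5.0, 5.0, 5.0]]), 1e-6)  # far outside [0, 1)
print(inds[:5])  # [100 101 102 103 104]
print(miss)      # [-1]
```

Because only the nearest neighbour is returned, duplicate rows in big_array that fall within the tolerance are not all reported; query_ball_tree remains the right call for that case.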