I have two big arrays to work on, but the following simplified example gives the idea. I would like to find whether an element in data1 matches an element in data2 and, if a match is found, return the indices in both arrays as a new array of the form [index of data1, index of data2]. For example, with the data1 and data2 below, the program will return:
data1 = [[1,1],[2,5],[623,781]]
data2 = [[1,1], [161,74],[357,17],[1,1]]
expected_output = [[0,0],[0,3]]
My current code is as follows:
import numpy as np

result = []
# Compare every row of data1 with every row of data2 (quadratic time).
for index, item in enumerate(data1):
    for index2, item2 in enumerate(data2):
        if np.array_equal(item, item2):
            result.append([index, index2])
>>> result
[[0, 0], [0, 3]]
This works fine. However, the two actual arrays that I am working on have about 0.6 million items each, so the above code would be extremely slow. Is there any way to speed up the process?
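Probably not the fastest, but easy and reasonably fast: use KD-trees. Assuming the coordinates are integers (as in the example), distinct points are at least 1 apart, so a search radius of 0.5 finds exact matches only: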
>>> data1 = [[1,1],[2,5],[623,781]]
>>> data2 = [[1,1], [161,74],[357,17],[1,1]]
>>>
>>> import numpy as np
>>> from operator import itemgetter
>>> from scipy.spatial import cKDTree as KDTree
>>>
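>>> def intersect(a, b):
...     # X[i] lists the indices of points in b within distance 0.5 of a[i].
...     A = KDTree(a); B = KDTree(b); X = A.query_ball_tree(B, 0.5)
...     # Keep only the points of a that have at least one match in b.
...     ai, bi = zip(*filter(itemgetter(1), enumerate(X)))
...     # Repeat each index of a once per matching index of b.
...     ai = np.repeat(ai, np.fromiter(map(len, bi), int, len(ai)))
...     bi = np.concatenate(bi)
...     return ai, bi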
...
>>> intersect(data1, data2)
(array([0, 0]), array([0, 3]))
Two fake data sets of 1,000,000 pairs each take about 3 seconds:
>>> from time import perf_counter
>>>
>>> a = np.random.randint(0, 100000, (1000000, 2))
>>> b = np.random.randint(0, 100000, (1000000, 2))
>>> t = perf_counter(); intersect(a, b); s = perf_counter()
(array([ 971, 3155, 15034, 35844, 41173, 60467, 73758, 91585,
97136, 105296, 121005, 121658, 124142, 126111, 133593, 141889,
150299, 165881, 167420, 174844, 179410, 192858, 222345, 227722,
233547, 234932, 243683, 248863, 255784, 264908, 282948, 282951,
285346, 287276, 302142, 318933, 327837, 328595, 332435, 342289,
344780, 350286, 355322, 370691, 377459, 401086, 412310, 415688,
442978, 461111, 469857, 491504, 493915, 502945, 506983, 507075,
511610, 515631, 516080, 532457, 541138, 546281, 550592, 551751,
554482, 568418, 571825, 591491, 594428, 603048, 639900, 648278,
666410, 672724, 708500, 712873, 724467, 740297, 740640, 749559,
752723, 761026, 777911, 790371, 791214, 793415, 795352, 801873,
811260, 815527, 827915, 848170, 861160, 892562, 909555, 918745,
924090, 929919, 933605, 939789, 940788, 940958, 950718, 950804,
997947]), array([507017, 972033, 787596, 531935, 590375, 460365, 17480, 392726,
552678, 545073, 128635, 590104, 251586, 340475, 330595, 783361,
981598, 677225, 80580, 38991, 304132, 157839, 980986, 881068,
308195, 162984, 618145, 68512, 58426, 190708, 123356, 568864,
583337, 128244, 106965, 528053, 626051, 391636, 868254, 296467,
39446, 791298, 356664, 428875, 143312, 356568, 736283, 902291,
5607, 475178, 902339, 312950, 891330, 941489, 93635, 884057,
329780, 270399, 633109, 106370, 626170, 54185, 103404, 658922,
108909, 641246, 711876, 496069, 835306, 745188, 328947, 975464,
522226, 746501, 642501, 489770, 859273, 890416, 62451, 463659,
884001, 980820, 171523, 222668, 203244, 149955, 134192, 369508,
905913, 839301, 758474, 114597, 534015, 381467, 7328, 447698,
651929, 137424, 975677, 758923, 982976, 778075, 95266, 213456,
210555]))
>>> print(s-t)
2.98617472499609
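If only exact matches are needed, a hash lookup might be worth trying as an alternative to the KD-tree. The following is a minimal sketch, not something I have benchmarked against the timing above: the function name match_indices is made up for illustration, and it assumes each row can be converted to a tuple of hashable values (true for lists or NumPy rows of integers). It indexes data2 once and then looks up each row of data1, so the work grows roughly with the total number of rows plus the number of matches.

```python
from collections import defaultdict

def match_indices(data1, data2):
    """Return [i, j] for every pair where data1[i] equals data2[j]."""
    # Map each row of data2 (as a hashable tuple) to every index where it occurs.
    positions = defaultdict(list)
    for j, row in enumerate(data2):
        positions[tuple(row)].append(j)
    # Look up each row of data1; rows without a match contribute nothing.
    return [[i, j]
            for i, row in enumerate(data1)
            for j in positions.get(tuple(row), ())]

data1 = [[1, 1], [2, 5], [623, 781]]
data2 = [[1, 1], [161, 74], [357, 17], [1, 1]]
print(match_indices(data1, data2))  # [[0, 0], [0, 3]]
```

Unlike the radius search, this compares rows for exact equality, so it does not depend on the coordinates being integers, but it will not treat nearly equal floating-point rows as matches.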