
Compare 2D arrays row-wise


This problem comes from the spatial analysis of unstructured grids in 3D. I have two 2D arrays to compare, each with 3 columns for the xyz coordinates. One of the arrays is a reference, and the other is evaluated against it (it is the result of a cKDTree query against the reference array). In the end, I want the number of matches for each row of the reference. I have tried to find a solution based on array concatenation, but I get lost in the different dimensions.

import numpy as np

reference = np.array([[0, 1, 33], [0, 33, 36], [0, 2, 36], [1, 33, 34]])
query = np.array([[0, 1, 33], [0, 1, 33], [1, 33, 34], [0, 33, 36], [0, 33, 36], [0, 1, 33], [0, 33, 36]])

Something like this is where I am heading:

filter=reference[:,:,None]==query.all(axis=0)
result = filter.sum(axis=1)

but I cannot find the right way to broadcast the arrays so that the rows of the two arrays are compared. The result should be:

np.array([3,3,0,1])

Solution

  • You need to broadcast the two arrays against each other. Comparing them element-wise gives a 3D boolean array, and a query row matches a reference row only if all of its coordinates are equal, so you first need to reduce that array using all along the last axis. Then you can count the matched rows with sum. Here is the resulting code:

    (reference[None,:,:] == query[:,None,:]).all(axis=2).sum(axis=0)
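To see what each step of that one-liner does, here is a small sketch that splits it up on the arrays from the question and prints the intermediate shapes (k = 7 query rows, m = 4 reference rows, n = 3 coordinates). The variable names equal, row_match and counts are illustrative only.

```python
import numpy as np

reference = np.array([[0, 1, 33], [0, 33, 36], [0, 2, 36], [1, 33, 34]])
query = np.array([[0, 1, 33], [0, 1, 33], [1, 33, 34], [0, 33, 36],
                  [0, 33, 36], [0, 1, 33], [0, 33, 36]])

# Singleton axes make the arrays broadcast to shape (k, m, n):
# reference[None, :, :] is (1, 4, 3) and query[:, None, :] is (7, 1, 3).
equal = reference[None, :, :] == query[:, None, :]
print(equal.shape)  # (7, 4, 3)

# A query row matches a reference row only if all n coordinates are equal.
row_match = equal.all(axis=2)
print(row_match.shape)  # (7, 4)

# Summing over the query axis counts the hits for each reference row.
counts = row_match.sum(axis=0)
print(counts)  # [3 3 0 1]
```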
    

    That being said, this solution is not the most efficient for bigger arrays. For m rows of size n in reference and k rows in query, it runs in O(n m k) time (and builds an intermediate boolean array with n m k elements), while the optimal solution is O(n m + n k).

    This can be achieved using hash maps (aka dict). The idea is to put the rows of reference in a hash map with all values set to 0, and then, for each row of query, increment the value whose key is that row. One then just needs to iterate over the hash map to get the final array. Hash map accesses are done in (amortized) constant time. Unfortunately, Python dicts do not accept arrays as keys, since arrays cannot be hashed, but tuples can. Here is an example:

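    # Start every reference row at a count of zero (tuples are hashable, arrays are not).
    counts = {tuple(row): 0 for row in reference}

    for row in query:
        key = tuple(row)
        # Only count query rows that also appear in reference.
        if key in counts:
            counts[key] += 1

    print(list(counts.values()))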
    

    This prints [3, 3, 0, 1].

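    Note that many hash maps do not preserve insertion order, but Python dicts do (this is guaranteed since Python 3.7), so the counts come out in the same order as the rows of reference, provided reference has no duplicated rows (duplicates collapse into a single key). Alternatively, one can use another hash map to rebuild the final array.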
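One possible reading of that second-hash-map suggestion, sketched here as an assumption rather than the answer's own code, maps each distinct reference row to the list of positions where it occurs. The final array is then filled by index, so it stays aligned with reference without relying on dict ordering, even when reference contains duplicated rows. The names positions and result are hypothetical.

```python
import numpy as np

reference = np.array([[0, 1, 33], [0, 33, 36], [0, 2, 36], [1, 33, 34]])
query = np.array([[0, 1, 33], [0, 1, 33], [1, 33, 34], [0, 33, 36],
                  [0, 33, 36], [0, 1, 33], [0, 33, 36]])

# Second hash map: each distinct reference row -> its indices in reference.
positions = {}
for i, row in enumerate(reference.tolist()):
    positions.setdefault(tuple(row), []).append(i)

# Fill the final array by index; query rows absent from reference are skipped.
result = np.zeros(len(reference), dtype=int)
for row in query.tolist():
    for i in positions.get(tuple(row), ()):
        result[i] += 1

print(result)  # [3 3 0 1]
```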

    The hash map solution may be slower for small arrays, since it loops in Python rather than in NumPy's compiled code, but it should scale better for huge ones.
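To put the complexity argument into numbers with hypothetical sizes n = 3, m = 10^4 and k = 10^5: the broadcast comparison touches n m k = 3×10^9 elements (about 3 GB as a NumPy boolean array, one byte per element), while the hash map approach visits roughly n (m + k) = 3.3×10^5 elements. Where the crossover lies depends on the machine, so the sketch below simply times both approaches on random integer coordinates. The sizes, the random data and the function names count_broadcast and count_hash are assumptions made for illustration; reference is deduplicated with np.unique so that the dict-based version, which collapses duplicate rows, returns one count per reference row.

```python
import timeit

import numpy as np


def count_broadcast(reference, query):
    # Materializes a (k, m, n) boolean array: O(n*m*k) time and memory.
    return (reference[None, :, :] == query[:, None, :]).all(axis=2).sum(axis=0)


def count_hash(reference, query):
    # Dict approach from the answer; rows are converted with .tolist()
    # so the tuple keys hold plain Python ints.
    counts = {tuple(row): 0 for row in reference.tolist()}
    for row in query.tolist():
        key = tuple(row)
        if key in counts:
            counts[key] += 1
    return np.array(list(counts.values()))


# Hypothetical test data: random integer coordinates, sizes chosen arbitrarily.
rng = np.random.default_rng(0)
reference = np.unique(rng.integers(0, 10, size=(200, 3)), axis=0)
query = rng.integers(0, 10, size=(2000, 3))

# Both approaches must agree before their timings are worth comparing.
assert np.array_equal(count_broadcast(reference, query), count_hash(reference, query))

for func in (count_broadcast, count_hash):
    seconds = timeit.timeit(lambda: func(reference, query), number=10)
    print(f"{func.__name__}: {seconds / 10:.6f} s per call")
```

Increasing the sizes (while keeping an eye on the memory used by count_broadcast) shows where the hash map starts to win on a given machine.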