
Intersect float rows in two numpy matrices with precision


I'm looking for rows in B that are close to any of the rows in A.

Example: eps = 0.1

A = [[1.22, 1.33], [1.45, 1.66]]

B = [[1.25, 1.34], [1.77, 1.66], [1.44, 1.67]]

Result: [[1.22, 1.33], [1.45, 1.66]]
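The question does not define "close" precisely. One way to pin it down is a plain double loop that treats two rows as close when every coordinate differs by less than eps, which is the rule the answer below uses. This is a slow reference sketch for checking a vectorized solution on small inputs, not a recommended approach. The helper name close_rows_loop is hypothetical. It returns both the rows of A and the rows of B that have a close partner, because the wording above asks for rows of B while the listed result contains rows of A.

```python
import numpy as np

def close_rows_loop(A, B, eps):
    """Hypothetical reference helper: return (rows of A, rows of B) that
    have at least one close partner in the other array.

    Assumption: rows a and b are "close" when |a - b| < eps holds for
    every column (a per-coordinate tolerance).
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    keep_a = np.zeros(len(A), dtype=bool)
    keep_b = np.zeros(len(B), dtype=bool)
    # Compare every row of A with every row of B
    for i, a in enumerate(A):
        for j, b in enumerate(B):
            if np.all(np.abs(a - b) < eps):
                keep_a[i] = True
                keep_b[j] = True
    return A[keep_a], B[keep_b]

A = [[1.22, 1.33], [1.45, 1.66]]
B = [[1.25, 1.34], [1.77, 1.66], [1.44, 1.67]]
rows_a, rows_b = close_rows_loop(A, B, eps=0.1)
print(rows_a.tolist())  # [[1.22, 1.33], [1.45, 1.66]]
print(rows_b.tolist())  # [[1.25, 1.34], [1.44, 1.67]]
```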


Solution

  • If you want to keep the rows of A that are close to any row of B, you can use np.broadcast_to and np.tile to do an exhaustive pairwise check:

    import numpy as np
    
    eps = .1
    A = np.array([[1.22, 1.33], [1.45, 1.66]])
    B = np.array([[1.25, 1.34], [1.77, 1.66], [1.44, 1.67]])
    
    
    # broadcast A to shape (len(B), len(A), n_cols), so A_ext[i, j] == A[j]
    A_ext = np.broadcast_to(A, (B.shape[0],) + A.shape)
    
    # repeat each row of B len(A) times and reshape, so B_ext[i, j] == B[i];
    # this pairs every row of A with every row of B
    B_ext = np.tile(B, A.shape[0]).reshape(A_ext.shape)
    
    # element-wise closeness test for every (B row, A row, column) triple
    A_bool = np.abs(A_ext - B_ext) < eps
    
    # reduce to one flag per row of A:
    # .all(axis=-1) marks which (B row, A row) pairs are close in every column
    # .any(axis=0) marks the rows of A that are close to at least one row of B
    A_mask = A_bool.all(axis=-1).any(axis=0)
    
    # final result: the rows of A that have a close match in B
    print(A[A_mask])
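The same pairwise comparison can also be written with plain broadcasting, inserting new axes with None instead of calling np.broadcast_to and np.tile. The sketch below builds a single boolean matrix close[i, j] (row i of A is close to row j of B) and reduces it along either axis, so it returns the rows of A computed above as well as the rows of B that the question literally asks for. The function name close_rows is hypothetical. Like the original, it materializes an intermediate array of shape (len(A), len(B), n_cols), so memory grows with the product of the two row counts.

```python
import numpy as np

def close_rows(A, B, eps):
    """Hypothetical helper: return the rows of A and the rows of B that
    have a close partner in the other array.

    Uses the same closeness rule as the answer: every coordinate must
    differ by less than eps.
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # diff[i, j, k] = A[i, k] - B[j, k]; shape (len(A), len(B), n_cols)
    diff = A[:, None, :] - B[None, :, :]
    # close[i, j] is True when row i of A is within eps of row j of B
    close = (np.abs(diff) < eps).all(axis=-1)
    # keep a row of A if it matches any row of B, and vice versa
    return A[close.any(axis=1)], B[close.any(axis=0)]

A = np.array([[1.22, 1.33], [1.45, 1.66]])
B = np.array([[1.25, 1.34], [1.77, 1.66], [1.44, 1.67]])
rows_a, rows_b = close_rows(A, B, eps=0.1)
print(rows_a.tolist())  # [[1.22, 1.33], [1.45, 1.66]]
print(rows_b.tolist())  # [[1.25, 1.34], [1.44, 1.67]]
```

Computing the pair matrix once and reducing it along axis 1 or axis 0 avoids writing two separate comparisons for the two directions.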