
How to remove non-symmetric pairs in a numpy array?


Given an Nx2 numpy array data of ints (we can assume that data has no duplicate rows), I need to keep only the rows i for which some row j (possibly i itself) satisfies the relationship

(data[i,0] == data[j,1]) & (data[i,1] == data[j,0])

For instance with

import numpy as np
data = np.array([[1, 2], 
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])

I should return

array([[1, 2], # because 2,1 is present
       [2, 1], # because 1,2 is present
       [6, 6]]) # because 6,6 is present

One verbose way to do this is

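def filter_symmetric_pairs(data):
  # Start empty but with the input's integer dtype (np.empty defaults to float)
  result = np.empty((0, 2), dtype=data.dtype)
  for i in range(len(data)):
    for j in range(len(data)):
      # Row i is kept if some row j is its mirror image
      if (data[i, 0] == data[j, 1]) and (data[i, 1] == data[j, 0]):
        result = np.vstack([result, data[i, :]])
  return result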

and I came up with a more concise version:

def filter_symmetric_pairs(data):
  return data[[row.tolist() in data[:,::-1].tolist() for row in data]]
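The list-based one-liner re-evaluates data[:,::-1].tolist() for every row and then scans that list linearly, so it does quadratic work. A minimal sketch of the same membership idea using a Python set of row tuples, which gives constant-time lookups on average; filter_symmetric_pairs_set is a hypothetical name, not something from the question:

```python
import numpy as np

def filter_symmetric_pairs_set(data):
    # Build the set of rows (as tuples) once: O(N) expected time and memory
    rows = set(map(tuple, data.tolist()))
    # Keep row (a, b) only if its mirror (b, a) is also one of the rows
    mask = np.array([(b, a) in rows for a, b in data.tolist()], dtype=bool)
    return data[mask]

data = np.array([[1, 2], [2, 1], [7, 3], [6, 6], [5, 6]])
print(filter_symmetric_pairs_set(data))
```

This still loops in Python, so it is not a pure numpy idiom, but it avoids both the N x N comparison and any assumption about the value range.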

Can somebody suggest a better numpy idiom?


Solution

  • Here are a couple of different methods you can use. The first one is the "obvious" quadratic solution: it is simple, but it builds an N x N boolean matrix, which can exhaust memory for a big input array. The second one works as long as the range of values in the input is not huge (roughly (data.max() + 1)**2 has to fit in the array's integer dtype), and it has the advantage of using only a linear amount of memory.

    import numpy as np
    
    # Input data
    data = np.array([[1, 2],
                     [2, 1],
                     [7, 3],
                     [6, 6],
                     [5, 6]])
    
    # Method 1 (quadratic memory)
    d0, d1 = data[:, 0, np.newaxis], data[:, 1]
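    # Compare all values in first column to all values in second column:
    # c[i, j] is True when data[i, 0] == data[j, 1]
    c = d0 == d1
    # Also require data[i, 1] == data[j, 0], i.e. row j is the mirror of row i
    c &= c.T
    # Keep rows that have a mirror (c is now symmetric, so any(0) == any(1))
    res = data[c.any(0)]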
    print(res)
    # [[1 2]
    #  [2 1]
    #  [6 6]]
    
    # Method 2 (linear memory)
    # Convert pairs into single values
    # (assumes non-negative values; otherwise subtract data.min() first)
    n = data.max() + 1
    v = data[:, 0] + (n * data[:, 1])
    # Encoding of each row's mirror: (b, a) -> b + n*a, computed from (a, b)
    v2 = (n * data[:, 0]) + data[:, 1]
    # Keep rows whose mirror encoding appears among the row encodings
    m = np.isin(v2, v)
    res = data[m]
    print(res)
    # [[1 2]
    #  [2 1]
    #  [6 6]]
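Method 2's comment says negative values must be shifted first, and the encoding can silently overflow when the value range is large. A minimal sketch that does both steps explicitly, assuming a non-empty signed-integer array whose max - min fits in int64; filter_symmetric_pairs_encoded is a hypothetical helper name:

```python
import numpy as np

def filter_symmetric_pairs_encoded(data):
    # Shift so the smallest value becomes 0 (the encoding needs non-negative ints).
    # Assumes data is a non-empty signed-integer array.
    shifted = data.astype(np.int64) - data.min()
    n = int(shifted.max()) + 1
    # The largest key is (n - 1) + n * (n - 1) = n * n - 1; refuse if n * n would overflow int64
    if n > np.iinfo(np.int64).max // n:
        raise OverflowError("value range too large to encode a pair as one int64")
    # Row (a, b) is encoded as a + n*b; its mirror (b, a) has key b + n*a
    v = shifted[:, 0] + n * shifted[:, 1]
    v2 = n * shifted[:, 0] + shifted[:, 1]
    # Keep rows whose mirror key is among the row keys
    return data[np.isin(v2, v)]

data = np.array([[-1, 2], [2, -1], [-3, -3], [0, 5]])
print(filter_symmetric_pairs_encoded(data))
```

Shifting by the minimum does not change which rows mirror each other, because (a, b) and (b, a) are shifted by the same amount; the original values are returned since the mask indexes into data, not shifted.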