Given an Nx2 NumPy array data of ints (we can assume that data
has no duplicate rows), I need to keep only the rows i for which some row j satisfies the relationship
(data[i,0] == data[j,1]) & (data[i,1] == data[j,0])
For instance with
import numpy as np
data = np.array([[1, 2],
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])
I should return
array([[1, 2],   # because 2,1 is present
       [2, 1],   # because 1,2 is present
       [6, 6]])  # because 6,6 is present
One verbose way to do this is
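def filter_symmetric_pairs(data):
    # Start from an empty result with the input's dtype (plain np.empty would give floats)
    result = np.empty((0, 2), dtype=data.dtype)
    for i in range(len(data)):
        for j in range(len(data)):
            if (data[i, 0] == data[j, 1]) & (data[i, 1] == data[j, 0]):
                result = np.vstack([result, data[i, :]])
    return result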
and I came up with a more concise version:
def filter_symmetric_pairs(data):
    return data[[row.tolist() in data[:, ::-1].tolist() for row in data]]
Can somebody suggest a better numpy idiom?
Here are two ways to do this. The first is the "obvious" quadratic solution: it is simple, but it builds an NxN boolean matrix, so it may give you trouble with a big input array. The second packs each pair into a single integer, so it works as long as the range of values in the input is not so large that the packed values overflow the integer dtype, and it has the advantage of needing only a linear amount of memory.
import numpy as np
# Input data
data = np.array([[1, 2],
                 [2, 1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])
# Method 1 (quadratic memory)
d0, d1 = data[:, 0, np.newaxis], data[:, 1]
# Compare all values in first column to all values in second column:
# c[i, j] is True where data[i, 0] == data[j, 1]
c = d0 == d1
# Keep only matches that hold both ways (c.T[i, j] is data[i, 1] == data[j, 0])
c = c & c.T
# Get matching elements
res = data[c.any(0)]
print(res)
# [[1 2]
# [2 1]
# [6 6]]
# Method 2 (linear memory)
# Convert pairs into single values
# (assumes non-negative values, otherwise shift first;
# n * n must also fit in the integer dtype)
n = data.max() + 1
v = data[:, 0] + (n * data[:, 1])
# Symmetric values
v2 = (n * data[:, 0]) + data[:, 1]
# Find where symmetric is present
m = np.isin(v2, v)
res = data[m]
print(res)
# [[1 2]
# [2 1]
# [6 6]]
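If the input may contain negative values, the "shift first" step mentioned for Method 2 could look roughly like the sketch below. The function name filter_symmetric_pairs_encoded is only an illustrative choice, and the cast to int64 is an assumption made to reduce the risk of overflow when pairs are packed; it is not a guarantee for arbitrarily large value ranges.

```python
import numpy as np

def filter_symmetric_pairs_encoded(data):
    """Keep rows (a, b) whose reverse (b, a) also appears in data."""
    data = np.asarray(data)
    if data.size == 0:
        # Nothing to filter; min()/max() would fail on an empty array
        return data
    # Shift so the smallest value becomes 0, which makes negative inputs safe.
    # int64 is an assumption to keep n * n from overflowing for moderate ranges.
    shifted = data.astype(np.int64) - data.min()
    n = shifted.max() + 1
    # Encode each pair (a, b) and its reverse (b, a) as single integers
    forward = shifted[:, 0] + n * shifted[:, 1]
    reverse = n * shifted[:, 0] + shifted[:, 1]
    # A row is kept when its reversed encoding matches some row's encoding
    return data[np.isin(reverse, forward)]

data = np.array([[-1, 2],
                 [2, -1],
                 [7, 3],
                 [6, 6],
                 [5, 6]])
result = filter_symmetric_pairs_encoded(data)
print(result)
# [[-1  2]
#  [ 2 -1]
#  [ 6  6]]
```

The original values are returned unchanged, since the shift is applied only to the temporary copy used for encoding.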