In a complex-valued array `a` with `nsel = ~750000` elements, I repeatedly (>~10^6 iterations) update `nchange < ~1000` elements. After each iteration, I need to find the indices of the `K` largest values in the absolute-squared, real-valued array `b`. `K` can be assumed to be small: certainly `K <= ~50`, and in practice likely `K <= ~10`. The `K` indices do not need to be sorted.
The updated values and their indices change in each iteration. They depend on the (a priori) unknown elements of `a` that correspond to the largest values of `b` and their indices. Nonetheless, let us assume they are essentially random, with the exception that one specific element (typically (one of) the largest value(s)) is always included among the updated values. Important: after an update, the new largest value(s) might be among the non-updated elements.
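The semantics can be stated precisely with a brute-force reference. It returns any `K` indices whose values are the `K` largest, in no guaranteed order, and could serve as an oracle when testing a faster data structure. It is a minimal pure-Python sketch. `topk_indices_reference` is a hypothetical helper name, and at O(n log K) per call it is far too slow to be the actual solution.

```python
import heapq
import random


def topk_indices_reference(b, k):
    """Brute-force reference: indices of the k largest values of `b`.

    Scans all of `b` (O(n log k)); only meant as a correctness oracle.
    """
    return heapq.nlargest(k, range(len(b)), key=b.__getitem__)


random.seed(100)
b = [random.random() for _ in range(1000)]
ilarge = topk_indices_reference(b, 10)
```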
Below is a minimal example. For simplicity, it demonstrates only one of the 10^6 (looped) iterations. We can find the indices of the `K` largest values using `b.argmax()` (for `K = 1`) or `b.argpartition()` (arbitrary `K`, the general case, see https://stackoverflow.com/a/23734295/5269892). However, `b` is large (`nsel` elements), so scanning the full array to find the indices of the largest values is very slow. Combined with the large number of iterations, this step is the bottleneck of the larger code I am using (an implementation of the nonlinear deconvolution algorithm CLEAN), into which it is embedded.
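The two full-scan baselines look like this. `argpartition(-K)` moves the `K` largest values into the last `K` slots in no particular order. Because the question does not need the indices sorted, the final `argsort` step is optional and only shown for completeness. The array here is random data of the same size as `nsel`, an assumption for illustration. Both calls read all of `b` on every iteration, which is the cost the question is trying to avoid.

```python
import numpy as np

rng = np.random.default_rng(0)
b = rng.random(750_000)
K = 10

# K = 1: index of the maximum (full O(nsel) scan)
imax = b.argmax()

# arbitrary K: the K largest values end up in the last K slots, unordered
ilarge = b.argpartition(-K)[-K:]

# optional: order the K indices by decreasing value
ilarge_sorted = ilarge[np.argsort(b[ilarge])[::-1]]
```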
I have already asked how to find the largest value (the case `K = 1`) most efficiently, see "Python most efficient way to find index of maximum in partially changed array". The accepted solution accesses `b` only partially: it splits the data into chunks and (re-)computes the maxima only of the chunks in which some elements were updated. This achieves a speed-up of > 7x.
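A simplified NumPy sketch of the chunking idea for `K = 1` follows. It is not the code from the linked answer. `ChunkedArgmax` and `chunk_size` are hypothetical names and choices. The class keeps one cached maximum per chunk, rescans only the chunks touched by an update, and then takes the argmax over the small array of chunk maxima. A fast version would need a compiled loop (for example with Numba or Cython), because rescanning up to `nchange` chunks from a Python loop is slow.

```python
import numpy as np


class ChunkedArgmax:
    """Track argmax(b) via one cached maximum per fixed-size chunk of `b`."""

    def __init__(self, b, chunk_size=1024):
        self.b = b  # shared with the caller and updated in place
        self.chunk_size = chunk_size
        nchunks = -(-b.size // chunk_size)  # ceiling division
        self.chunk_max = np.empty(nchunks)
        self.chunk_arg = np.empty(nchunks, dtype=np.int64)
        for c in range(nchunks):
            self._rescan(c)

    def _rescan(self, c):
        # Recompute the maximum of chunk `c` and its global index
        lo = c * self.chunk_size
        seg = self.b[lo:lo + self.chunk_size]
        j = seg.argmax()
        self.chunk_max[c] = seg[j]
        self.chunk_arg[c] = lo + j

    def update(self, ichange, new_values):
        self.b[ichange] = new_values
        # Only chunks containing a changed index can have a new maximum
        for c in np.unique(ichange // self.chunk_size):
            self._rescan(c)

    def argmax(self):
        return int(self.chunk_arg[self.chunk_max.argmax()])


rng = np.random.default_rng(1)
b = rng.random(100_000)
tracker = ChunkedArgmax(b)
ichange = rng.choice(b.size, 1000, replace=False)
ichange[0] = tracker.argmax()  # as in the question, the current maximum is always updated
tracker.update(ichange, rng.random(1000))
imax = tracker.argmax()
```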
According to its author, @Jérôme Richard (thanks for your help!), this solution unfortunately cannot easily be generalized to `K > 1`. He suggested a binary search tree as a possible alternative. Now my questions:

How is such a binary tree implemented in practice, and how do we then find the indices of the largest values most efficiently (and, if possible, easily)? Do you have other solutions for the fastest way to repeatedly find the indices of the `K` largest values in the partially updated array?
Note: In each iteration I will need `b` (or a copy of it) again later as a numpy array. If possible, the solution should be mostly Python-based. Calling C from Python or using Cython or `numba` is OK. I currently use Python 3.7.6 and numpy 1.21.2.
```python
import numpy as np

# some array shapes ('nnu_use' and 'nm'), number of total values ('nvals'), number of selected values ('nsel';
# here 'nsel' == 'nvals'; in general 'nsel' <= 'nvals') and number of values to be changed ('nchange' << 'nsel')
nnu_use, nm = 10418//2 + 1, 144
nvals = nnu_use * nm
nsel = nvals
nchange = 1000

# number of largest peaks to be found
K = 10

# fix random seed, generate random 2D 'Fourier transform' ('a', complex-valued), compute power ('b', real-valued),
# and two 2D arrays for indices of axes 0 and 1
np.random.seed(100)
a = np.random.rand(nsel) + 1j * np.random.rand(nsel)
b = a.real ** 2 + a.imag ** 2
inu_2d = np.tile(np.arange(nnu_use)[:,None], (1,nm))
im_2d = np.tile(np.arange(nm)[None,:], (nnu_use,1))

# select 'nsel' random indices and get 1D arrays of the selected 2D indices
isel = np.random.choice(nvals, nsel, replace=False)
inu_sel, im_sel = inu_2d.flatten()[isel], im_2d.flatten()[isel]

def do_update_iter(a, b):
    # find index of maximum, choose 'nchange' indices of which 'nchange - 1' are random and the remaining one is the
    # index of the maximum, generate random complex numbers, update 'a' and compute updated 'b'
    imax = b.argmax()
    ichange = np.concatenate(([imax],np.random.choice(nsel, nchange-1, replace=False)))
    a_change = np.random.rand(nchange) + 1j*np.random.rand(nchange)
    a[ichange] = a_change
    b[ichange] = a_change.real ** 2 + a_change.imag ** 2
    return a, b, ichange

# do an update iteration on 'a' and 'b'
a, b, ichange = do_update_iter(a, b)

# find indices of largest K values
ilarge = b.argpartition(-K)[-K:]
```
I tried to implement a Cython solution based on C++ containers (for 64-bit float values). The good news is that it is faster than a naive `np.argpartition`. The bad news is that it is quite complex and not much faster: only 3 to 4 times.
One main issue is that Cython does not provide the `std::multimap` container, which is the most useful one here. It is possible to emulate it with a `std::map<Key, std::vector<Value>>` type, but that makes the code significantly more complex and also less efficient (due to the additional cache-unfriendly indirection in memory). If one can guarantee that there are no duplicates in `b`, performance can be significantly better (up to 2x), since a plain `std::map` can be used instead. Furthermore, Cython does not seem to accept recent C++11/C++17/C++20 features, which makes the code more cumbersome to read and write. This is unfortunate, since some features, such as `extract` (C++17 node handles, which allow changing a node's key without reallocating the node) and rvalue references, could make the code faster.
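One way to avoid duplicate keys altogether is to key on `(value, index)` pairs. This is an alternative and not what the Cython code below does. Every index occurs exactly once, so the pairs are unique even when `b` contains equal values. In C++ the analogous container would be a `std::set<std::pair<double, int>>`. The pure-Python sketch below keeps the pairs in a sorted list with `bisect`, so it only illustrates the remove-then-insert update and the top-K read. Deleting from or inserting into a Python list is O(n), so this is not a fast implementation. `SortedPairs` is a hypothetical name.

```python
import bisect


class SortedPairs:
    """Sorted list of unique (value, index) keys; no multimap needed for duplicates."""

    def __init__(self, b):
        self.b = list(b)
        self.keys = sorted((v, i) for i, v in enumerate(self.b))

    def update(self, idx, value):
        # Keys cannot be modified in place: remove the old key, insert the new one
        old = (self.b[idx], idx)
        pos = bisect.bisect_left(self.keys, old)
        assert self.keys[pos] == old
        del self.keys[pos]
        bisect.insort(self.keys, (value, idx))
        self.b[idx] = value

    def k_largest(self, k):
        # Indices of the k largest values (largest last)
        return [i for _, i in self.keys[-k:]] if k > 0 else []


pairs = SortedPairs([1.0, 3.0, 3.0, 2.0, 3.0])  # duplicate value 3.0
pairs.update(1, 0.5)
print(pairs.k_largest(2))  # → [2, 4]
```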
Another main issue is that the execution time is dominated by cache misses (>75% on my machine), because binary red-black (RB) trees are not cache friendly. The overall data structure is very likely bigger than the CPU caches. Indeed, at least `750_000*(8*2+4) = 15_000_000` bytes are required to store the key-values, and a similar amount of memory is needed for the node pointers of the tree data structure. Most processor caches are smaller than 30 MB. This is mainly a problem during the update because of random accesses: each lookup/insert requires about `log2(nsel)` (≈ 20) fetches from RAM, and RAM latency is typically several dozen nanoseconds. Additionally, (C++) RB-trees do not support key updates, so a remove+insert is required. I tried to mitigate this problem with a parallel prefetching approach. Unfortunately, it was generally slower in practice...
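The estimate above can be checked with a quick back-of-the-envelope calculation. The 50 ns RAM latency is an assumed value standing in for "several dozen nanoseconds", and the code pessimistically assumes that every node visit misses the cache. In reality the top levels of the tree stay cached, so this is only an order-of-magnitude upper bound. It lands in the same range as the measured update time.

```python
import math

nsel = 750_000
nchange = 1_000

bytes_per_item = 8 * 2 + 4              # the per-item estimate used above
payload_bytes = nsel * bytes_per_item   # 15_000_000 bytes, before node pointers

# Depth of a perfectly balanced binary tree; a red-black tree may be deeper,
# its height is bounded by 2*log2(n + 1)
tree_depth = math.log2(nsel)            # about 19.5

ram_latency_s = 50e-9                   # assumption: one RAM access per node visit
lookups_per_update = 2                  # remove the old key + insert the new key
worst_case_s = nchange * lookups_per_update * tree_depth * ram_latency_s
print(f"{payload_bytes} bytes, depth {tree_depth:.1f}, ~{worst_case_s * 1e3:.2f} ms per iteration")
```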
In practice, extracting the K largest items is very fast (a few microseconds for 1000 items with 750_000 values in the tree), while the update takes about 1.0-1.5 milliseconds. Meanwhile, `np.argpartition` takes ~4.5 milliseconds.
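A baseline timing like the `np.argpartition` figure can be reproduced with a small `timeit` harness. This is a minimal sketch on random data of the question's size, and the numbers it prints depend on the machine. Taking the minimum over several repeats reduces noise from other processes.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
b = rng.random(750_000)
K = 10

number = 10
# best of 5 repeats, each running argpartition `number` times; reported per call
t = min(timeit.repeat(lambda: b.argpartition(-K), number=number, repeat=5)) / number
print(f"np.argpartition: {t * 1e3:.2f} ms per call")
```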
Some people have reported (e.g. here) that `std::map` is actually quite slow when the number of items is large. It may therefore be a good idea to use another, non-standard C++ implementation. I expect B-trees to be faster in this case. The Google Abseil library contains such containers (e.g. `absl::btree_map` and `absl::btree_multimap`), and they should be significantly faster. That said, using them requires wrapping some code, which can be tedious. Alternatively, one can write a full C++ class and call it from Cython.
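Another binary-tree layout is also worth considering. This is a different technique from the RB-tree used below: a max "tournament" (segment) tree. It stores the elements in index order in an implicit array, and each internal node holds the index of the largest value in its subtree. No node is ever removed or re-keyed. Updating one element just recomputes the ~log2(n) ancestors. The `K` largest indices come from a best-first walk with a heap that touches roughly K·log2(n) nodes. The pure-Python sketch below is hypothetical (`TournamentTree`, `k_largest` are illustrative names) and copies `b` into a list. A practical version would keep `b` as the numpy array and compile the loops (e.g. with Numba or Cython).

```python
import heapq
import random


class TournamentTree:
    """Max segment tree: node i stores the index of the largest value below it.

    Node 1 is the root, node i has children 2*i and 2*i + 1,
    and leaf `size + j` holds index j of `b`.
    """

    def __init__(self, b):
        self.b = list(b)
        self.n = len(self.b)
        self.size = 1
        while self.size < self.n:
            self.size *= 2
        # -1 marks padding leaves beyond the end of `b`
        self.best = [-1] * (2 * self.size)
        for j in range(self.n):
            self.best[self.size + j] = j
        for node in range(self.size - 1, 0, -1):
            self._pull(node)

    def _pull(self, node):
        # Recompute the winner of `node` from its two children
        i, j = self.best[2 * node], self.best[2 * node + 1]
        if i < 0 or (j >= 0 and self.b[j] > self.b[i]):
            i = j
        self.best[node] = i

    def update(self, idx, value):
        # Change one value, then replay the matches on the path to the root
        self.b[idx] = value
        node = (self.size + idx) // 2
        while node >= 1:
            self._pull(node)
            node //= 2

    def k_largest(self, k):
        # Best-first search keyed on each subtree's maximum, so leaves
        # are popped in decreasing order of value
        result = []
        heap = [(-self.b[self.best[1]], 1)] if self.n > 0 else []
        while heap and len(result) < k:
            _, node = heapq.heappop(heap)
            if node >= self.size:  # leaf
                result.append(self.best[node])
                continue
            for child in (2 * node, 2 * node + 1):
                w = self.best[child]
                if w >= 0:
                    heapq.heappush(heap, (-self.b[w], child))
        return result


random.seed(0)
tree = TournamentTree(random.random() for _ in range(10_000))
for idx in random.sample(range(tree.n), 100):
    tree.update(idx, random.random())
top = tree.k_largest(10)  # indices, largest value first
```

Because elements never move inside the tree, an update touches a fixed, predictable path of nodes. The node array is also a single contiguous list rather than individually allocated nodes.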
Here is the implementation (and an example of usage at the end):
```cython
# distutils: language = c++
import numpy as np
cimport numpy as np
cimport cython

# See: https://cython.readthedocs.io/en/latest/src/userguide/wrapping_CPlusPlus.html
from libcpp.vector cimport vector
from libcpp.map cimport map
from libcpp.pair cimport pair
from cython.operator cimport dereference as deref, preincrement as inc

@cython.boundscheck(False)  # Deactivate bounds checking
@cython.wraparound(False)   # Deactivate negative indexing
cdef class MaxTree:
    cdef map[double, vector[int]] data
    cdef int itemCount

    # Build a tree from `b`
    def __init__(self, double[::1] b):
        cdef map[double, vector[int]].iterator it
        cdef pair[double, vector[int]] node
        cdef double val
        cdef int i

        # Temporary node used to ease insertion
        node.second.resize(1)

        # Iterate over the items of `b` to add them to the tree
        for i in range(b.size):
            val = b[i]
            it = self.data.find(val)
            if it == self.data.end():
                # Value not found: add a new node
                node.first = val
                node.second[0] = i
                self.data.insert(node)
            else:
                # Value found: add a new duplicate to the existing node
                deref(it).second.push_back(i)

        self.itemCount = b.size

    def size(self):
        return self.itemCount

    # Get the indices (in the original `b` array) of the K largest values
    def getKlargest(self, int count):
        cdef map[double, vector[int]].reverse_iterator rit
        cdef int vecSize
        cdef int* vecData
        cdef int i, j
        cdef int[::1] resultView

        if count > self.itemCount:
            count = self.itemCount
        if count <= 0:
            return np.empty(0, dtype=np.int32)

        result = np.empty(count, dtype=np.int32)
        resultView = result

        i = 0
        rit = self.data.rbegin()
        while rit != self.data.rend():
            vecSize = deref(rit).second.size()
            vecData = deref(rit).second.data()

            # Note: indices are not always sorted here due to the update
            for j in range(vecSize-1, -1, -1):
                resultView[i] = vecData[j]
                i += 1
                count -= 1
                if count <= 0:
                    return result  # return the numpy array, not the memoryview

            inc(rit)

        return result

    # Set the values of `b` at the indices `index` to `values` and update the tree accordingly
    def update(self, double[::1] b, int[::1] index, double[::1] values):
        cdef map[double, vector[int]].iterator it
        cdef pair[double, vector[int]] node
        # libcpp's map.insert returns pair[iterator, bint]
        cdef pair[map[double, vector[int]].iterator, bint] infos
        cdef int idx, i, j, vecSize
        cdef double oldValue, newValue
        cdef int* vecData

        assert b.size == self.itemCount
        assert index.size == values.size
        assert np.min(index) >= 0 and np.max(index) < b.size

        # Temporary node used to ease insertion
        node.second.resize(1)

        for i in range(index.size):
            idx = index[i]
            oldValue = b[idx]
            newValue = values[i]

            # 1) Remove `idx` from the node holding its old value
            # Assume `index` is correct/coherent and the tree is correctly updated for the sake of performance
            it = self.data.find(oldValue)
            assert it != self.data.end()
            if deref(it).second.size() == 1:
                # Last index with this value: remove the whole node (keys are immutable)
                self.data.erase(it)
            else:
                # Duplicates: remove `idx` from the node's index vector
                vecData = deref(it).second.data()
                vecSize = deref(it).second.size()
                for j in range(vecSize):
                    if vecData[j] == idx:
                        vecData[j] = vecData[vecSize-1]
                        deref(it).second.pop_back()
                        break

            # 2) Insert `idx` under its new value (in both cases, including duplicates)
            node.first = newValue
            node.second[0] = idx
            infos = self.data.insert(node)
            if not infos.second:
                # The new value already exists: add a duplicate to that node
                it = infos.first
                deref(it).second.push_back(idx)

            # Update `b`
            b[idx] = newValue
```
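```python
# setup.py
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

# `cimport numpy` in maxtree.pyx needs the NumPy C headers at compile time
setup(ext_modules=cythonize([
    Extension("maxtree", ["maxtree.pyx"], include_dirs=[np.get_include()], language="c++"),
]))
```

```shell
# Build (in the directory containing maxtree.pyx and setup.py):
python3 setup.py build_ext --inplace -q
```

```python
# Usage:
import numpy as np
import maxtree

np.random.seed(0)
b = np.random.rand(750_000)
nchange = 1_000
ichange = np.random.randint(0, b.size, nchange).astype(np.int32)

tree = maxtree.MaxTree(b)
tree.getKlargest(nchange)
tree.update(b, ichange, b[ichange]*0.999)
```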
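To check the compiled tree against the full-scan baseline, compare values rather than indices, so that any valid tie-breaking is accepted. This is a minimal sketch. `check_topk` is a hypothetical helper. After an update, one would call it as `check_topk(tree.getKlargest(K), b, K)`. The demo below uses `argpartition` output so that the snippet runs without the compiled module.

```python
import numpy as np


def check_topk(ilarge, b, K):
    """True if `ilarge` holds K distinct indices whose values are the K largest of `b`."""
    ilarge = np.asarray(ilarge)
    if ilarge.size != K or np.unique(ilarge).size != K:
        return False
    expected = np.sort(np.partition(b, -K)[-K:])
    return bool(np.array_equal(np.sort(b[ilarge]), expected))


rng = np.random.default_rng(0)
b = rng.random(750_000)
K = 10
ok = check_topk(b.argpartition(-K)[-K:], b, K)
```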