NumPy - np.searchsorted for 2-D arrays


np.searchsorted works only for 1D arrays.

I have a lexicographically sorted 2D array: row 0 is sorted; for equal values in row 0, the corresponding elements of row 1 are sorted; for equal values in row 1, the corresponding elements of row 2 are sorted, and so on. In other words, the tuples formed by the columns are sorted.
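As a minimal sketch of what "lexicographically sorted by columns" means here, the snippet below sorts the columns of a small array with np.lexsort. The sample data is made up for illustration only.

```python
import numpy as np

# Hypothetical sample data: 3 key rows, 6 columns (each column is one tuple).
a = np.array([
    [2, 1, 2, 1, 1, 2],
    [0, 5, 0, 5, 3, 1],
    [7, 2, 4, 1, 9, 0],
])

# np.lexsort treats the LAST key as the primary one, so the rows are reversed
# to make row 0 the primary key, row 1 the secondary key, and so on.
order = np.lexsort(a[::-1])
a_sorted = a[:, order]

# Read as tuples, the columns are now in ascending lexicographic order.
cols = [tuple(c) for c in a_sorted.T.tolist()]
print(cols)
```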

I have another 2D array whose columns (tuples) need to be inserted into the first array at the correct column positions. In the 1D case, np.searchsorted is the usual way to find those positions.

Is there an alternative to np.searchsorted for 2D arrays? Something analogous to how np.lexsort is the 2D counterpart of the 1D np.argsort.

If there is no such function, can this functionality be implemented efficiently using existing NumPy functions?

I am interested in efficient solutions for arrays of any dtype including np.object_.

One naive way to handle any dtype would be to convert each column of both arrays to a 1D array (or tuple) and store these columns in another 1D array of dtype np.object_. Maybe it is not that naive and could even be fast, especially if the columns are long.
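A minimal sketch of this tuple-based idea, assuming both arrays have the same number of rows and that the first one is already sorted by columns. The sample arrays are illustrative only.

```python
import numpy as np

def to_tuples(x):
    """Pack each column of a 2-D array into a tuple inside a 1-D object array."""
    r = np.empty(x.shape[1], dtype=np.object_)
    for i, col in enumerate(x.T.tolist()):
        r[i] = tuple(col)
    return r

# Hypothetical data: columns of `a` are already lexicographically sorted.
a = np.array([[1, 1, 2, 2],
              [3, 5, 0, 4]])
v = np.array([[1, 2, 3],
              [4, 0, 0]])

# Tuples compare lexicographically, so the regular 1-D search applies.
ix = np.searchsorted(to_tuples(a), to_tuples(v), side='left')
print(ix)

# Inserting the columns of `v` at those positions keeps the result sorted.
merged = np.insert(a, ix, v, axis=1)
print(merged)
```

The conversion to tuples costs time and memory on every call, so caching the tuple array of the larger input may be worthwhile when it is searched repeatedly.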


Solution

  • I've implemented several more advanced strategies.

    A simple strategy using tuples, as in my other answer, is also implemented.

    All solutions are timed.

    Most of the strategies use np.searchsorted as the underlying engine. To implement them, a special wrapper class _CmpIx supplies a custom comparison function (__lt__) for the np.searchsorted call.
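    The wrapper idea can be sketched in a few lines. This is a simplified illustration, not the _CmpIx class itself: ColumnKey and wrap_columns are hypothetical names, and the sketch assumes that np.searchsorted on object arrays relies on the elements' rich comparison methods.

```python
import functools
import numpy as np

@functools.total_ordering
class ColumnKey:
    """Hypothetical wrapper that orders columns lexicographically."""
    __slots__ = ('col',)

    def __init__(self, col):
        self.col = col

    def __eq__(self, other):
        return bool(np.array_equal(self.col, other.col))

    def __lt__(self, other):
        # Lazy comparison: stop at the first differing element.
        for l, r in zip(self.col, other.col):
            if l < r:
                return True
            if l > r:
                return False
        return False

def wrap_columns(x):
    """Wrap every column of a 2-D array in a ColumnKey (1-D object array)."""
    out = np.empty(x.shape[1], dtype=np.object_)
    for i in range(x.shape[1]):
        out[i] = ColumnKey(x[:, i])
    return out

# Hypothetical data: columns of `a` are already lexicographically sorted.
a = np.array([[1, 1, 2, 2],
              [3, 5, 0, 4]])
v = np.array([[1, 2, 3],
              [4, 0, 0]])

ix_left = np.searchsorted(wrap_columns(a), wrap_columns(v), side='left')
ix_right = np.searchsorted(wrap_columns(a), wrap_columns(v), side='right')
print(ix_left, ix_right)
```

    functools.total_ordering derives the remaining comparison operators from __eq__ and __lt__, which keeps the sketch short at the cost of some extra call overhead.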

    1. py.tuples converts all columns to tuples, stores them in a 1D NumPy array of dtype np.object_, and then does a regular np.searchsorted.
    2. py.zip uses Python's zip to do the same comparison lazily.
    3. np.lexsort uses np.lexsort to compare two columns lexicographically.
    4. np.nonzero uses the expression np.flatnonzero(a != b).
    5. cmp_numba uses ahead-of-time compiled Numba code inside the _CmpIx wrapper for fast, lazy lexicographic comparison of two elements.
    6. np.searchsorted uses the standard NumPy function, but is measured for the 1D case only.
    7. numba: the whole search algorithm is implemented from scratch with Numba, based on binary search. There are _py and _nm variants: _nm is much faster because it is compiled by Numba, while _py is the same algorithm uncompiled. There is also a _sorted flavor that applies an extra optimization when the array to be inserted is already sorted.
    8. view1d: methods suggested by @MadPhysicist in this answer. I commented them out in the code because they returned incorrect answers in most tests for all key lengths > 1, probably due to problems with viewing the raw array memory.
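    The row-by-row narrowing behind the binary-search strategy (item 7) can be sketched in plain NumPy. This is a simplified illustration rather than the Numba implementation: searchsorted_columns is a hypothetical name, and the sketch assumes a numeric dtype and delegates each per-row binary search to the 1D np.searchsorted.

```python
import numpy as np

def searchsorted_columns(a, v, side='left'):
    """Sketch: insertion indices of the columns of v into column-sorted a."""
    nk, na = a.shape
    out = np.empty(v.shape[1], dtype=np.int64)
    for j in range(v.shape[1]):
        lo, hi = 0, na
        for k in range(nk):
            if lo >= hi:
                break  # no column of `a` shares the prefix; position is fixed
            # Columns in [lo, hi) agree on rows 0..k-1, so row k is sorted there.
            row = a[k, lo:hi]
            left = lo + np.searchsorted(row, v[k, j], side='left')
            right = lo + np.searchsorted(row, v[k, j], side='right')
            lo, hi = left, right
        out[j] = lo if side == 'left' else hi
    return out

# Hypothetical data: columns of `a` are already lexicographically sorted.
a = np.array([[1, 1, 2, 2],
              [3, 5, 0, 4]])
v = np.array([[1, 2, 3],
              [4, 0, 0]])

print(searchsorted_columns(a, v, side='left'))
print(searchsorted_columns(a, v, side='right'))
```

    Each key row is only examined while the candidate range is non-empty, so the search is lazy in the key length.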

    class SearchSorted2D:
        class _CmpIx:
            def __init__(self, t, p, i):
                self.p, self.i = p, i
                self.leg = self.leg_cache()[t]
                self.lt = lambda o: self.leg(self, o, False) if self.i != o.i else False
                self.le = lambda o: self.leg(self, o, True) if self.i != o.i else True
            @classmethod
            def leg_cache(cls):
                if not hasattr(cls, 'leg_cache_data'):
                    cls.leg_cache_data = {
                        'py.zip': cls._leg_py_zip, 'np.lexsort': cls._leg_np_lexsort,
                        'np.nonzero': cls._leg_np_nonzero, 'cmp_numba': cls._leg_numba_create(),
                    }
                return cls.leg_cache_data
            def __eq__(self, o): return not self.lt(o) and self.le(o)
            def __ne__(self, o): return self.lt(o) or not self.le(o)
            def __lt__(self, o): return self.lt(o)
            def __le__(self, o): return self.le(o)
            def __gt__(self, o): return not self.le(o)
            def __ge__(self, o): return not self.lt(o)
            @staticmethod
            def _leg_np_lexsort(self, o, eq):
                import numpy as np
                ia, ib = (self.i, o.i) if eq else (o.i, self.i)
                return (np.lexsort(self.p.ab[::-1, ia : (ib + (-1, 1)[ib >= ia], None)[ib == 0] : ib - ia])[0] == 0) == eq
            @staticmethod
            def _leg_py_zip(self, o, eq):
                for l, r in zip(self.p.ab[:, self.i], self.p.ab[:, o.i]):
                    if l < r:
                        return True
                    if l > r:
                        return False
                return eq
            @staticmethod
            def _leg_np_nonzero(self, o, eq):
                import numpy as np
                a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
                ix = np.flatnonzero(a != b)
                return a[ix[0]] < b[ix[0]] if ix.size != 0 else eq
            @staticmethod
            def _leg_numba_create():
                import numpy as np
    
                try:
                    from numba.pycc import CC
                    cc = CC('ss_numba_mod')
                    @cc.export('ss_numba_i8', 'b1(i8[:],i8[:],b1)')
                    def ss_numba(a, b, eq):
                        for i in range(a.size):
                            if a[i] < b[i]:
                                return True
                            elif b[i] < a[i]:
                                return False
                        return eq
                    cc.compile()
                    success = True
                except:    
                    success = False
                    
                if success:
                    try:
                        import ss_numba_mod
                    except:
                        success = False
                
                def odo(self, o, eq):
                    a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
                    assert a.ndim == 1 and a.shape == b.shape, (a.shape, b.shape)
                    return ss_numba_mod.ss_numba_i8(a, b, eq)
                    
                return odo if success else None
    
        def __init__(self, type_):
            import numpy as np
            self.type_ = type_
            self.ci = np.array([], dtype = np.object_)
        def __call__(self, a, b, *pargs, **nargs):
            import numpy as np
            self.ab = np.concatenate((a, b), axis = 1)
            self._grow(self.ab.shape[1])
            ix = np.searchsorted(self.ci[:a.shape[1]], self.ci[a.shape[1] : a.shape[1] + b.shape[1]], *pargs, **nargs)
            return ix
        def _grow(self, to):
            import numpy as np
            if self.ci.size >= to:
                return
            import math
            to = 1 << math.ceil(math.log(to) / math.log(2))
            self.ci = np.concatenate((self.ci, [self._CmpIx(self.type_, self, i) for i in range(self.ci.size, to)]))
    
    class SearchSorted2DNumba:
        @classmethod
        def do(cls, a, v, side = 'left', *, vsorted = False, numba_ = True):
            import numpy as np
    
            if not hasattr(cls, '_ido_numba'):
                def _ido_regular(a, b, vsorted, lrt):
                    nk, na, nb = a.shape[0], a.shape[1], b.shape[1]
                    res = np.zeros((2, nb), dtype = np.int64)
                    max_depth = 0
                    if nb == 0:
                        return res, max_depth
                    #lb, le, rb, re = 0, 0, 0, 0
                    lrb, lre = 0, 0
                    
                    if vsorted:
                        brngs = np.zeros((nb, 6), dtype = np.int64)
                        brngs[0, :4] = (-1, 0, nb >> 1, nb)
                        i, j, size = 0, 1, 1
                        while i < j:
                            for k in range(i, j):
                                cbrng = brngs[k]
                                bp, bb, bm, be = cbrng[:4]
                                if bb < bm:
                                    brngs[size, :4] = (k, bb, (bb + bm) >> 1, bm)
                                    size += 1
                                bmp1 = bm + 1
                                if bmp1 < be:
                                    brngs[size, :4] = (k, bmp1, (bmp1 + be) >> 1, be)
                                    size += 1
                            i, j = j, size
                        assert size == nb
                        brngs[:, 4:] = -1
    
                    for ibc in range(nb):
                        if not vsorted:
                            ib, lrb, lre = ibc, 0, na
                        else:
                            ibpi, ib = int(brngs[ibc, 0]), int(brngs[ibc, 2])
                            if ibpi == -1:
                                lrb, lre = 0, na
                            else:
                                ibp = int(brngs[ibpi, 2])
                                if ib < ibp:
                                    lrb, lre = int(brngs[ibpi, 4]), int(res[1, ibp])
                                else:
                                    lrb, lre = int(res[0, ibp]), int(brngs[ibpi, 5])
                            brngs[ibc, 4 : 6] = (lrb, lre)
                            assert lrb != -1 and lre != -1
                            
                        for ik in range(nk):
                            if lrb >= lre:
                                if ik > max_depth:
                                    max_depth = ik
                                break
    
                            bv = b[ik, ib]
                            
                            # Binary searches
                            
                            if nk != 1 or lrt == 2:
                                cb, ce = lrb, lre
                                while cb < ce:
                                    cm = (cb + ce) >> 1
                                    av = a[ik, cm]
                                    if av < bv:
                                        cb = cm + 1
                                    elif bv < av:
                                        ce = cm
                                    else:
                                        break
                                lrb, lre = cb, ce
                                    
                            if nk != 1 or lrt >= 1:
                                cb, ce = lrb, lre
                                while cb < ce:
                                    cm = (cb + ce) >> 1
                                    if not (bv < a[ik, cm]):
                                        cb = cm + 1
                                    else:
                                        ce = cm
                                #rb, re = cb, ce
                                lre = ce
                                    
                            if nk != 1 or lrt == 0 or lrt == 2:
                                cb, ce = lrb, lre
                                while cb < ce:
                                    cm = (cb + ce) >> 1
                                    if a[ik, cm] < bv:
                                        cb = cm + 1
                                    else:
                                        ce = cm
                                #lb, le = cb, ce
                                lrb = cb
                                
                            #lrb, lre = lb, re
                                
                        res[:, ib] = (lrb, lre)
                        
                    return res, max_depth
    
                cls._ido_regular = _ido_regular
                
                import numba
                cls._ido_numba = numba.jit(nopython = True, nogil = True, cache = True)(cls._ido_regular)
                
            assert side in ['left', 'right', 'left_right'], side
            a, v = np.array(a), np.array(v)
            assert a.ndim == 2 and v.ndim == 2 and a.shape[0] == v.shape[0], (a.shape, v.shape)
            res, max_depth = (cls._ido_numba if numba_ else cls._ido_regular)(
                a, v, vsorted, {'left': 0, 'right': 1, 'left_right': 2}[side],
            )
            return res[0] if side == 'left' else res[1] if side == 'right' else res
    
    def Test():
        import time
        import numpy as np
        np.random.seed(0)
        
        def round_float_fixed_str(x, n = 0):
            if type(x) is int:
                return str(x)
            s = str(round(float(x), n))
            if n > 0:
                s += '0' * (n - (len(s) - 1 - s.rfind('.')))
            return s
    
        def to_tuples(x):
            r = np.empty([x.shape[1]], dtype = np.object_)
            r[:] = [tuple(e) for e in x.T]
            return r
        
        searchsorted2d = {
            'py.zip': SearchSorted2D('py.zip'),
            'np.nonzero': SearchSorted2D('np.nonzero'),
            'np.lexsort': SearchSorted2D('np.lexsort'),
            'cmp_numba': SearchSorted2D('cmp_numba'),
        }
        
        for iklen, klen in enumerate([1, 1, 2, 5, 10, 20, 50, 100, 200]):
            times = {}
            for side in ['left', 'right']:
                a = np.zeros((klen, 0), dtype = np.int64)
                tac = to_tuples(a)
    
                for itest in range((15, 100)[iklen == 0]):
                    b = np.random.randint(0, (3, 100000)[iklen == 0], (klen, np.random.randint(1, (1000, 2000)[iklen == 0])), dtype = np.int64)
                    b = b[:, np.lexsort(b[::-1])]
                    
                    if iklen == 0:
                        assert klen == 1, klen
                        ts = time.time()
                        ix1 = np.searchsorted(a[0], b[0], side = side)
                        te = time.time()
                        times['np.searchsorted'] = times.get('np.searchsorted', 0.) + te - ts
                        
                    for cached in [False, True]:
                        ts = time.time()
                        tb = to_tuples(b)
                        ta = tac if cached else to_tuples(a)
                        ix1 = np.searchsorted(ta, tb, side = side)
                        if not cached:
                            ix0 = ix1
                        tac = np.insert(tac, ix0, tb) if cached else tac
                        te = time.time()
                        timesk = f'py.tuples{("", "_cached")[cached]}'
                        times[timesk] = times.get(timesk, 0.) + te - ts
    
                    for type_ in searchsorted2d.keys():
                        if iklen == 0 and type_ in ['np.nonzero', 'np.lexsort']:
                            continue
                        ss = searchsorted2d[type_]
                        try:
                            ts = time.time()
                            ix1 = ss(a, b, side = side)
                            te = time.time()
                            times[type_] = times.get(type_, 0.) + te - ts
                            assert np.array_equal(ix0, ix1)
                        except Exception:
                            times[type_ + '!failed'] = 0.
    
                    for numba_ in [False, True]:
                        for vsorted in [False, True]:
                            if numba_:
                                # Heat-up/pre-compile numba
                                SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
                            
                            ts = time.time()
                            ix1 = SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
                            te = time.time()
                            timesk = f'numba{("_py", "_nm")[numba_]}{("", "_sorted")[vsorted]}'
                            times[timesk] = times.get(timesk, 0.) + te - ts
                            assert np.array_equal(ix0, ix1)
    
    
                    # View-1D methods suggested by @MadPhysicist
                    if False: # Disabled: these methods give correct results only some of the time
                        aT, bT = np.copy(a.T), np.copy(b.T)
                        assert aT.ndim == 2 and bT.ndim == 2 and aT.shape[1] == klen and bT.shape[1] == klen, (aT.shape, bT.shape, klen)
                        
                        for ty in ['if', 'cf']:
                            try:
                                dt = np.dtype({'if': [('', b.dtype)] * klen, 'cf': [('row', b.dtype, klen)]}[ty])
                                ts = time.time()
                                va = np.ndarray(aT.shape[:1], dtype = dt, buffer = aT)
                                vb = np.ndarray(bT.shape[:1], dtype = dt, buffer = bT)
                                ix1 = np.searchsorted(va, vb, side = side)
                                te = time.time()
                                assert np.array_equal(ix0, ix1), (ix0.shape, ix1.shape, ix0[:20], ix1[:20])
                                times[f'view1d_{ty}'] = times.get(f'view1d_{ty}', 0.) + te - ts
                            except Exception:
                                raise
                    
                    a = np.insert(a, ix0, b, axis = 1)
                
            stimes = ([f'key_len: {str(klen).rjust(3)}'] +
                [f'{k}: {round_float_fixed_str(v, 4).rjust(7)}' for k, v in times.items()])
            nlines = 4
            print('-' * 50 + '\n' + ('', '!LARGE!:\n')[iklen == 0], end = '')
            for i in range(nlines):
                print(',  '.join(stimes[len(stimes) * i // nlines : len(stimes) * (i + 1) // nlines]), flush = True)
                
    Test()
    

    Output:

    --------------------------------------------------
    !LARGE!:
    key_len:   1,  np.searchsorted:  0.0250
    py.tuples_cached:  3.3113,  py.tuples: 30.5263,  py.zip: 40.9785
    cmp_numba: 25.7826,  numba_py:  3.6673
    numba_py_sorted:  6.8926,  numba_nm:  0.0466,  numba_nm_sorted:  0.0505
    --------------------------------------------------
    key_len:   1,  py.tuples_cached:  0.1371
    py.tuples:  0.4698,  py.zip:  1.2005,  np.nonzero:  4.7827
    np.lexsort:  4.4672,  cmp_numba:  1.0644,  numba_py:  0.2748
    numba_py_sorted:  0.5699,  numba_nm:  0.0005,  numba_nm_sorted:  0.0020
    --------------------------------------------------
    key_len:   2,  py.tuples_cached:  0.1131
    py.tuples:  0.3643,  py.zip:  1.0670,  np.nonzero:  4.5199
    np.lexsort:  3.4595,  cmp_numba:  0.8582,  numba_py:  0.4958
    numba_py_sorted:  0.6454,  numba_nm:  0.0025,  numba_nm_sorted:  0.0025
    --------------------------------------------------
    key_len:   5,  py.tuples_cached:  0.1876
    py.tuples:  0.4493,  py.zip:  1.6342,  np.nonzero:  5.5168
    np.lexsort:  4.6086,  cmp_numba:  1.0939,  numba_py:  1.0607
    numba_py_sorted:  0.9737,  numba_nm:  0.0050,  numba_nm_sorted:  0.0065
    --------------------------------------------------
    key_len:  10,  py.tuples_cached:  0.6017
    py.tuples:  1.2275,  py.zip:  3.5276,  np.nonzero: 13.5460
    np.lexsort: 12.4183,  cmp_numba:  2.5404,  numba_py:  2.8334
    numba_py_sorted:  2.3991,  numba_nm:  0.0165,  numba_nm_sorted:  0.0155
    --------------------------------------------------
    key_len:  20,  py.tuples_cached:  0.8316
    py.tuples:  1.3759,  py.zip:  3.4238,  np.nonzero: 13.7834
    np.lexsort: 16.2164,  cmp_numba:  2.4483,  numba_py:  2.6405
    numba_py_sorted:  2.2226,  numba_nm:  0.0170,  numba_nm_sorted:  0.0160
    --------------------------------------------------
    key_len:  50,  py.tuples_cached:  1.0443
    py.tuples:  1.4085,  py.zip:  2.2475,  np.nonzero:  9.1673
    np.lexsort: 19.5266,  cmp_numba:  1.6181,  numba_py:  1.7731
    numba_py_sorted:  1.4637,  numba_nm:  0.0415,  numba_nm_sorted:  0.0405
    --------------------------------------------------
    key_len: 100,  py.tuples_cached:  2.0136
    py.tuples:  2.5380,  py.zip:  2.2279,  np.nonzero:  9.2929
    np.lexsort: 33.9505,  cmp_numba:  1.5722,  numba_py:  1.7158
    numba_py_sorted:  1.4208,  numba_nm:  0.0871,  numba_nm_sorted:  0.0851
    --------------------------------------------------
    key_len: 200,  py.tuples_cached:  3.5945
    py.tuples:  4.1847,  py.zip:  2.3553,  np.nonzero: 11.3781
    np.lexsort: 66.0104,  cmp_numba:  1.8153,  numba_py:  1.9449
    numba_py_sorted:  1.6463,  numba_nm:  0.1661,  numba_nm_sorted:  0.1651
    

    As the timings show, the numba_nm implementation is the fastest: it outperforms the next fastest (py.zip or py.tuples_cached) by roughly 15-100x. Its speed is comparable to (about 1.85x slower than) the standard np.searchsorted in the 1D case. The _sorted flavor, which exploits the fact that the inserted array is already sorted, does not improve things.

    The machine-code-compiled cmp_numba method is around 1.5x faster on average than py.zip, which runs the same algorithm in pure Python. Because the average maximum equal-key depth is only around 15-18 elements, Numba doesn't gain much here. If the depth were in the hundreds, the Numba code would probably show a much larger speedup.

    The py.tuples_cached strategy is faster than py.zip for key lengths <= 100.

    np.lexsort also turns out to be very slow: either it is not optimized for the case of just two columns, or it spends time on preprocessing such as splitting rows into a list, or it does a non-lazy lexicographic comparison. The last is probably the real reason, since np.lexsort slows down as the key length grows.

    The np.nonzero strategy is also non-lazy and therefore slow; it too slows down as the key length grows, though not as quickly as np.lexsort.
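    To illustrate the lazy versus non-lazy distinction, here is a small sketch of both comparison styles side by side. The names less_lazy and less_eager are hypothetical, and the arrays are made up for illustration.

```python
import numpy as np

def less_lazy(a, b):
    """Stops at the first differing element (the py.zip style)."""
    for l, r in zip(a, b):
        if l != r:
            return bool(l < r)
    return False

def less_eager(a, b):
    """Compares every element before deciding (the np.nonzero style)."""
    ix = np.flatnonzero(a != b)
    return bool(a[ix[0]] < b[ix[0]]) if ix.size else False

a = np.array([1, 2, 3, 4])
b = np.array([1, 2, 5, 0])

# Both agree on the result; only the amount of work differs.
print(less_lazy(a, b), less_eager(a, b))
```

    The eager version always touches the whole key, so its cost grows with the key length even when the first elements already differ.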

    The timings above may be imprecise, because my CPU drops its core frequency by 2-2.3x at random times whenever it overheats, and it overheats often because it is a powerful CPU inside a laptop.