
Pass a custom comparer to a priority queue in Cython


The Cython libcpp module contains a template for priority_queue, which is great, except for one thing: I cannot pass it a custom comparer (or, at least, I don't know how to).

I need this because I want the priority_queue to do an argsort of sorts rather than a plain sort (yes, a priority queue is optimal for what I want to do), and it has to be fast.

Is this possible within Cython, perhaps by wrapping a queue in a custom way, or not at all?

As an example, say I want to sort a vector[int[:]] by one of the elements in a stable way. The actual algorithm is much more complicated.

I decide to sort it by pushing its elements onto the priority_queue one by one, but I don't know how to do that.
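
A minimal sketch of that approach in plain C++ (not Cython), assuming each row is a std::vector<long> rather than a memoryview; stable_argsort_by_column is a hypothetical name. std::priority_queue is not stable by itself, so each entry stores the key together with the row's original index, and the index breaks ties.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical helper: a stable "argsort" of rows by one column, done by
// pushing every row onto a std::priority_queue. The queue is not stable by
// itself, so each entry carries the row's original index as a tie-breaker.
std::vector<std::size_t> stable_argsort_by_column(
        const std::vector<std::vector<long>>& rows, std::size_t column) {
    using Entry = std::pair<long, std::size_t>;  // (key, original row index)
    // std::greater turns the default max-heap into a min-heap: top() is the
    // smallest key and, among equal keys, the smallest (earliest) index.
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> heap;
    for (std::size_t i = 0; i < rows.size(); ++i) {
        assert(column < rows[i].size());  // row must contain the key element
        heap.emplace(rows[i][column], i);
    }
    std::vector<std::size_t> order;
    order.reserve(rows.size());
    while (!heap.empty()) {
        order.push_back(heap.top().second);
        heap.pop();
    }
    return order;
}
```

For rows {0, 3}, {1, 1}, {2, 3}, {3, 0} and column 1 this returns 3, 1, 0, 2: rows 0 and 2 share the key 3 and keep their original order. Because the key travels inside each entry, std::greater on the pair is enough and no custom comparator object is needed.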

My actual operation is similar to this question, except that I'm merging sorted lists of 1-D int[:] arrays by one particular element, and each original list is already ordered by that element.
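
A minimal sketch of such a merge (a k-way merge) in plain C++ (not Cython), assuming each list is a std::vector of rows and each row is a std::vector<long>; merge_sorted_by_column is a hypothetical name. The heap holds at most one candidate per input list, and ties on the key are broken by list number and then position, which keeps the merge stable.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <tuple>
#include <vector>

using Row = std::vector<long>;

// Hypothetical helper: merge lists of rows that are each already sorted by
// row[column] into one list sorted by that column. Equal keys come out in
// input order (earlier list first, then position within the list).
std::vector<Row> merge_sorted_by_column(const std::vector<std::vector<Row>>& lists,
                                        std::size_t column) {
    using Entry = std::tuple<long, std::size_t, std::size_t>;  // (key, list, position)
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> heap;  // min-heap

    // Seed the heap with the first row of every non-empty list.
    for (std::size_t li = 0; li < lists.size(); ++li) {
        if (!lists[li].empty()) {
            assert(column < lists[li][0].size());
            heap.emplace(lists[li][0][column], li, std::size_t{0});
        }
    }

    std::vector<Row> merged;
    while (!heap.empty()) {
        const Entry top = heap.top();
        heap.pop();
        const std::size_t src = std::get<1>(top);
        const std::size_t pos = std::get<2>(top);
        merged.push_back(lists[src][pos]);
        // Refill from the list we just took a row from.
        if (pos + 1 < lists[src].size()) {
            assert(column < lists[src][pos + 1].size());
            heap.emplace(lists[src][pos + 1][column], src, pos + 1);
        }
    }
    return merged;
}
```

For example, merging [{0, 1}, {1, 4}] with [{2, 1}, {3, 2}] on column 1 yields the rows starting 0, 2, 3, 1.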

I don't mind converting the objects to a buffer/pointer.


Solution

  • It is possible to make this work, but I really don't recommend it. The main problems are:

    • C++ containers can't readily hold Python objects (such as memoryviews) unless you're prepared to write reference-counting wrapper code (in C++).
    • You can get a C pointer to the first element of a memoryview, but:
      • you must ensure that a reference to the underlying object (the one that owns the memory) is kept, otherwise Python will free it and you'll be accessing invalid memory;
      • a pointer loses all information about how long the array is.
    • You're fairly limited in the comparators you can use (they must be expressible as a cdef function). For example, I've written one that compares the second element of each array, but switching to the third element would require recompiling.
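
    That limitation comes from the Cython side. In C++ itself, std::function can hold a capturing lambda, so the element to compare can be chosen at run time. A sketch, where compare_element and PtrQueue are hypothetical names; using it from Cython would need extra glue code (not shown) to expose compare_element:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <vector>

// Same shape as the typedef in cpp_priority_queue.hpp: long* elements,
// std::function comparator.
using PtrQueue =
    std::priority_queue<long*, std::vector<long*>, std::function<bool(long*, long*)>>;

// Hypothetical factory: returns a comparator on element `column`. The lambda
// captures `column` by value, so which element is compared is chosen at run
// time. As with compare_2nd_element, array lengths are not checked.
std::function<bool(long*, long*)> compare_element(std::size_t column) {
    return [column](long* a, long* b) -> bool {
        assert(a != nullptr && b != nullptr);
        return a[column] < b[column];
    };
}
```

    With rows {0, 5, 9} and {1, 7, 2}, a queue built from compare_element(1) has {1, 7, 2} on top, while one built from compare_element(2) has {0, 5, 9} on top.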

    Therefore, my advice is to find another way of doing it. However, if you do want to try:

    You need to write a very small C++ header that typedefs the priority queue type you want. It uses std::function as the comparator, and I've assumed the queue stores pointers to longs (long*). This file is needed because Cython's template support is fairly limited.

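    // cpp_priority_queue.hpp
    #pragma once
    #include <functional>
    #include <queue>
    #include <vector>
    
    // priority queue of long* whose ordering is decided at run time by a std::function comparator
    using intp_std_func_prority_queue = std::priority_queue<long*,std::vector<long*>,std::function<bool(long*,long*)>>;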
    

    Because of this, you can't use the libcpp.queue.priority_queue wrapper that ships with Cython. Instead, write your own that wraps just the functions you need (priority_queue_wrap.pyx):

    # distutils: language = c++
    
    from libcpp cimport bool
    
    cdef extern from "cpp_priority_queue.hpp":
        cdef cppclass intp_std_func_prority_queue:
            intp_std_func_prority_queue(...) # get Cython to accept any arguments and let C++ deal with getting them right
            void push(long*)
            long* top()
            void pop()
            bool empty()
    
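    # noexcept: this is called from C++, which can't propagate a Python exception
    cdef bool compare_2nd_element(long* a, long* b) noexcept:
        # note - no protection if allocated memory isn't long enough
        # std::priority_queue is a max-heap: with "<" the element with the LARGEST
        # second value is at top(). It is also not stable: ties come out in no particular order.
        return a[1] < b[1]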
    
    
    def example_function(list _input):
        # takes a list of "memoryviewable" objects
        cdef intp_std_func_prority_queue queue = intp_std_func_prority_queue(compare_2nd_element) # cdef function is convertible to a function pointer (and from there to std::function)
    
        cdef long[::1] list_element_mview
        cdef long* queue_element
    
    
        for x in _input:
            list_element_mview = x
            assert list_element_mview.shape[0] >= 2 # check we have enough elements to compare the second one
            queue.push(&list_element_mview[0]) # push a pointer to the first element
    
        while not queue.empty():
            queue_element = queue.top(); queue.pop()
            print(queue_element[0],queue_element[1]) # print the first two elements (we don't know that we have any more)
    

    I've then created an example function that goes through a list of memoryview-compatible objects, takes a pointer to the first element of each one, and pushes it onto the queue. Finally, it goes through the queue in order and prints what it can. Note that the input list outlives the queue: it holds the references that keep the arrays, and therefore the pointers, valid!
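
    For comparison, a sketch of the same loop in plain C++, using the typedef from cpp_priority_queue.hpp (repeated so the snippet stands alone); drain_by_2nd_element is a hypothetical name. A std::vector of rows plays the role of the Python list that owns the memory, and it has to outlive the queue for the same reason. Because std::priority_queue is a max-heap, compare_2nd_element as written puts the row with the largest second element on top, so rows come out in descending order; flipping < to > gives ascending order.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <vector>

// Same typedef as cpp_priority_queue.hpp (including its spelling).
using intp_std_func_prority_queue =
    std::priority_queue<long*, std::vector<long*>, std::function<bool(long*, long*)>>;

// Plays the role of the cdef comparator in priority_queue_wrap.pyx.
bool compare_2nd_element(long* a, long* b) { return a[1] < b[1]; }

// Hypothetical C++ analogue of example_function: `rows` owns the memory and
// must outlive the queue, which only stores pointers to each row's first
// element. Returns the second elements in the order they leave the queue.
std::vector<long> drain_by_2nd_element(std::vector<std::vector<long>>& rows) {
    intp_std_func_prority_queue pq(compare_2nd_element);  // function pointer -> std::function
    for (auto& row : rows) {
        assert(row.size() >= 2);  // same check as the Cython version
        pq.push(row.data());      // pointer to the first element
    }
    std::vector<long> seen;
    while (!pq.empty()) {
        long* element = pq.top();
        pq.pop();
        seen.push_back(element[1]);
    }
    return seen;
}
```

    With second elements 4, 9, 1, 6 in the input, the queue yields them as 9, 6, 4, 1.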

    Finally, a quick Python test script (run after compiling priority_queue_wrap.pyx) that creates a suitable list:

    import priority_queue_wrap
    import numpy as np
    
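    # the Cython code expects C long elements (long[::1]), so convert explicitly to dtype 'l'
    a = np.vstack([np.arange(20), np.random.randint(0, 10, size=(20,))]).astype(np.dtype('l'))
    l = [a[:, n].copy() for n in range(a.shape[1])]  # one contiguous 2-element array per column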
    
    print(l)
    priority_queue_wrap.example_function(l)
    

    In summary, Python objects + Cython + C++ is a mess: I don't recommend doing it this way (but you can try!)