Search code examples
algorithmcythonquicksortquickselect

A Quickselect C Algorithm faster than C Qsort


I have tried to implement a C QuickSelect algorithm as described in this post (3 way quicksort (C implementation)). However, its performance is 5 to 10 times worse than that of the default qsort (even with an initial shuffle). I tried to dig into the original qsort source code as provided here (https://github.com/lattera/glibc/blob/master/stdlib/qsort.c), but it is too complex. Does anybody have a simpler and better algorithm? Any ideas are welcome. Thanks. NB: my original problem is to move the K smallest values of an array to its first K indices, so I planned to call quickselect K times. EDIT 1: Here is the Cython code, copied and adapted from the link above:

cdef void qswap(void* a, void* b, const size_t size) nogil:
    # Exchange the `size` bytes stored at `a` with the `size` bytes stored at `b`.
    #
    # The swap is done one byte at a time, so no temporary buffer is needed:
    # a runtime-sized `cdef char temp[size]` is a C99 VLA, which Cython does not
    # accept, and malloc would add a failure path to a hot helper.
    # Calling it with a == b is harmless (every byte is swapped with itself).
    cdef unsigned char* pa = <unsigned char*>a
    cdef unsigned char* pb = <unsigned char*>b
    cdef unsigned char tmp
    cdef size_t i

    for i in range(size):
        tmp = pa[i]
        pa[i] = pb[i]
        pb[i] = tmp

cdef void qshuffle(void* base, size_t num, size_t size) nogil:
    # In-place Fisher-Yates shuffle of `num` elements of `size` bytes each,
    # stored contiguously starting at `base`.
    #
    # Walks i from the last index down to 1 and swaps element i with a randomly
    # chosen element j in [0, i]. Randomness comes from c_rand(), so the usual
    # caveats of rand() apply (modulo bias, limited range).
    cdef char* data = <char*>base  # byte pointer: arithmetic on void* is not portable
    cdef size_t i, j

    # Nothing to shuffle; also avoids `num - 1` wrapping around when num == 0.
    if num < 2:
        return

    i = num - 1
    while i > 0:
        j = <size_t>c_rand() % (i + 1)
        qswap(data + i * size, data + j * size, size)
        i -= 1

cdef void partition3(void* base,
                      size_t *low, size_t *high, size_t size,
                      QComparator compar) nogil:       
    # Three-way partition of the elements base[low[0] .. high[0]] (inclusive
    # element indices, each element `size` bytes wide).
    # The pivot is the element initially stored at index low[0]; no median-of-three
    # or other pivot selection is performed here.
    # On return, low[0] and high[0] hold the final values of lt and gt, i.e. the
    # first and last index of the run of elements that compare equal to the pivot.
    cdef void *ptr = base
    cdef size_t lt = low[0]
    cdef size_t gt = high[0] # the pivot value lives at index lt
    cdef size_t i = lt + 1# scanning starts after the pivot: no need to compare it with itself
    cdef int c = 0

    while (i <= gt):
        # Always compare against base[lt], which holds a pivot-equal element.
        c = compar(ptr + i * size, ptr + lt * size)
        if (c < 0):# base[i] < pivot => swap(base[lt], base[i]), then advance both lt and i
            qswap(ptr + lt * size, ptr + i * size, size)
            i += 1
            lt += 1
        elif (c > 0):# base[i] > pivot => swap(base[i], base[gt]) and shrink gt; i stays, the swapped-in element is not yet examined
            qswap(ptr + i * size, ptr + gt* size, size)
            gt -= 1
        else:# base[i] == pivot => leave it in the middle run
            i += 1
    # Layout now: [ < pivot | lt ... == pivot ... gt | > pivot ]
    low[0] = lt                                          
    high[0] = gt 


cdef void qselectk3(void* base, size_t lo, size_t hi,
   size_t size, size_t k,
   QComparator compar) nogil:
    # Quickselect driven by the three-way partition: repeatedly partitions the
    # index window [first, last] (initially [lo, hi]) and narrows it to the side
    # that contains index k - 1. Once k - 1 falls inside the pivot-equal run, the
    # element at the start of that run is swapped with the element at `base`
    # and the search stops.
    cdef size_t first = lo
    cdef size_t last = hi
    cdef size_t target = k - 1
    cdef size_t eq_lo, eq_hi

    while True:
        eq_lo = first
        eq_hi = last
        partition3(base, &eq_lo, &eq_hi, size, compar)

        if target < eq_lo:
            # target lies in the less-than-pivot part: keep `first`, pull `last` in
            last = eq_lo - 1
        elif target <= eq_hi:
            # target lies in the equal-to-pivot run
            qswap(base, base + size*eq_lo, size)
            return
        else:
            # target lies in the greater-than-pivot part: keep `last`, push `first` out
            first = eq_hi + 1

"""
 A selection algorithm to find the nth smallest elements in an unordered list. 
 these elements ARE placed at the nth positions of the input array                                                                         
"""
# Moves the n smallest elements (according to `compar`) of the `num`-element
# array at `base` into its first n slots, in ascending order.
#
# The array is shuffled first, then for each k in [0, n) qselectk3 is asked for
# the smallest element of the suffix starting at slot k, which it leaves in
# that slot. Requests for more than `num` elements are clamped to `num`.
# NOTE(review): every iteration re-partitions the remaining suffix, so the cost
# grows with n * num; this is the likely reason it loses against a plain qsort.
cdef void qselect(void* base, size_t num, size_t size,
                              size_t n,
                              QComparator compar) nogil:
    cdef char* data = <char*>base  # byte pointer: arithmetic on void* is not portable
    cdef size_t count = n
    cdef size_t k

    # Empty input: nothing to shuffle or select (and `num - k - 1` would wrap).
    if num == 0:
        return
    # Never select more elements than the array holds.
    if count > num:
        count = num

    qshuffle(base, num, size)
    for k in range(count):
        qselectk3(data + size*k, 0, num - k - 1, size, 1, compar)

I use Python's timeit to measure the performance of both methods, pyselect (with N=50) and pysort, like this:

def testPySelect():
    """Build 10000 random int32 values in [0, 16) and run pyselect on them with 50."""
    values = np.random.randint(16, size=10000, dtype=np.int32)
    pyselect(values, 50)
# Single timed run; the measurement includes the array creation above.
timeit.timeit(testPySelect, number=1)

def testPySort():
    """Build 10000 random int32 values in [0, 16) and run pysort on them."""
    values = np.random.randint(16, size=10000, dtype=np.int32)
    pysort(values)
# Single timed run; the measurement includes the array creation above.
timeit.timeit(testPySort, number=1)

Solution

  • Here is a quick implementation for your purpose: qsort_select is a simple implementation of qsort with automatic pruning of unnecessary ranges.

    Without the && lb < top test, it behaves like a regular qsort, except for pathological cases where more advanced versions have better heuristics. This extra test prevents complete sorting of ranges that lie outside the target 0 .. (k-1). The function selects the k smallest values and sorts them; the rest of the array holds the remaining values in an unspecified order.

    #include <stdio.h>
    #include <stdint.h>
    
    /* Swap two objects of `size` bytes, one byte at a time. */
    static void exchange_bytes(uint8_t *ac, uint8_t *bc, size_t size) {
        if (ac == bc)
            return;         /* swapping an element with itself is a no-op */
        while (size-- > 0) {
            uint8_t t = *ac;
            *ac++ = *bc;
            *bc++ = t;
        }
    }
    
    /*
     * Select and sort the k smallest elements of an array.
     *
     * base   - first element of the array
     * nmemb  - number of elements
     * size   - size of one element in bytes
     * compar - qsort-style comparator (<0, 0, >0)
     * k      - how many of the smallest elements are wanted (clamped to nmemb)
     *
     * On return the first min(k, nmemb) slots hold the smallest elements in
     * ascending order; the remaining slots hold the other elements in an
     * unspecified order. Nothing is done when base or compar is NULL, when
     * size or k is 0, or when there are fewer than two elements.
     *
     * This is an iterative quicksort that skips every segment starting at or
     * beyond slot k. Segments are tracked as half-open element index ranges
     * [lb, end), so no pointer is ever formed before the start of the array.
     */
    void qsort_select(void *base, size_t nmemb, size_t size,
                      int (*compar)(const void *a, const void *b), size_t k)
    {
        /* The larger side is always deferred and the smaller side processed
         * first, so the number of pending segments is bounded by log2(nmemb);
         * 64 entries cover any size_t count. */
        struct { size_t lb, end; } stack[64], *sp = stack;
        uint8_t *arr = base;
        size_t top;
    
        if (base == NULL || compar == NULL || nmemb < 2 || size == 0 || k == 0)
            return;
    
        top = k < nmemb ? k : nmemb;
        sp->lb = 0;
        sp->end = nmemb;
        sp++;
        while (sp > stack) {
            size_t lb, end;
    
            --sp;
            lb = sp->lb;
            end = sp->end;
            /* `lb < top` prunes segments that lie entirely outside 0 .. k-1 */
            while (end - lb > 1 && lb < top) {
                uint8_t *pivot = arr + lb * size;
                size_t i = lb + 1;
                size_t j = end - 1;
    
                /* select middle element as pivot and exchange with 1st element */
                exchange_bytes(pivot, arr + (lb + (end - 1 - lb) / 2) * size, size);
    
                /* partition into two segments; j never drops below lb because
                 * it is only decremented while j >= i >= lb + 1 */
                for (;; i++, j--) {
                    while (i < j && compar(pivot, arr + i * size) > 0)
                        i++;
                    while (j >= i && compar(arr + j * size, pivot) > 0)
                        j--;
                    if (i >= j)
                        break;
                    exchange_bytes(arr + i * size, arr + j * size, size);
                }
                /* move pivot where it belongs */
                exchange_bytes(pivot, arr + j * size, size);
    
                /* keep processing smallest segment, and stack largest; a
                 * segment is only stacked if it could still need work */
                if (j - lb <= end - 1 - j) {
                    if (j + 1 < top && end - (j + 1) > 1) {
                        sp->lb = j + 1;
                        sp->end = end;
                        sp++;
                    }
                    end = j;
                } else {
                    if (j - lb > 1) {
                        sp->lb = lb;
                        sp->end = j;
                        sp++;
                    }
                    lb = j + 1;
                }
            }
        }
    }
    
    /* qsort-style comparator for int: returns -1, 0 or 1 without risking the
     * overflow of a plain subtraction. */
    int int_cmp(const void *a, const void *b) {
        int aa = *(const int *)a;
        int bb = *(const int *)b;
        return (aa > bb) - (aa < bb);
    }
    
    #define ARRAY_SIZE  50000
    
    int array[ARRAY_SIZE];
    
    /* Demo driver: fills the global array with ARRAY_SIZE, ARRAY_SIZE-1, ..., 1,
     * asks qsort_select for the 50 smallest values and prints them on one
     * comma-separated line. */
    int main(void) {
        const int wanted = 50;  /* how many of the smallest values to select and print */
        int n;
    
        for (n = 0; n < ARRAY_SIZE; n++)
            array[n] = ARRAY_SIZE - n;
    
        qsort_select(array, ARRAY_SIZE, sizeof(*array), int_cmp, wanted);
    
        for (n = 1; n <= wanted; n++)
            printf("%d%c", array[n - 1], n == wanted ? '\n' : ',');
    
        return 0;
    }