Tags: performance, assembly, simd, micro-optimization, avx

Using SIMD/AVX/SSE for tree traversal


I am currently researching whether it is possible to speed up the traversal of a van Emde Boas tree (or any tree). Given a single search query as input, and with multiple tree nodes already sharing a cache line (van Emde Boas layout), tree traversal seems to be instruction-bottlenecked.

Being fairly new to SIMD/AVX/SSE instructions, I would like to know from experts on the topic whether it is possible to compare multiple nodes to a value at once and then determine which tree path to follow. My research led to the following questions:

How many CPU cycles/instructions are spent constructing the SIMD/AVX/SSE registers, etc.? Using SIMD would be pointless if construction takes more time than traversing the whole subtree manually (2+4+8 nodes in one 64-byte cache line).

How many CPU cycles/instructions are spent finding the SIMD/AVX/SSE register that holds the answer of which path to follow? Could anybody come up with a smart way to use those findMinimumInteger AVX instructions to decide that in one (?) CPU cycle?

What is your guess?

Another, trickier approach to speeding up tree traversal would be to run multiple search queries down the tree at once, when there is a high probability that they land in nearby nodes on the last tree level. Any guesses on this? Of course, queries that no longer belong to the same subtree would have to be put aside and then found recursively after the first "parallel traversal" of the tree finishes. The queries have sequential, though not constant, access patterns (query[i] is always < query[i+1]).
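One simple form of running several queries down the tree at once is lockstep interleaving: every query in a batch advances one level per outer iteration, so the memory loads of different queries are independent and can be in flight together. The sketch below rests on assumptions that are not from the question. It uses an Eytzinger (BFS) layout instead of van Emde Boas, only because its child arithmetic is a one-liner (children of node k are 2k and 2k+1). It requires a perfect tree of 2^h - 1 keys, so every query takes exactly h steps. It has no put-aside step, because each query keeps its own index and nothing has to diverge. `make_eytzinger` and `batch_lower_bound` are hypothetical names, and `__builtin_ctzll` is a GCC/Clang builtin.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// In-order fill of a 1-based Eytzinger array: children of node k are 2k and 2k+1.
static std::size_t eytzinger_fill(std::vector<int32_t> const& sorted,
                                  std::vector<int32_t>& b, std::size_t i, std::size_t k) {
    if (k < b.size()) {
        i = eytzinger_fill(sorted, b, i, 2 * k);      // left subtree first
        b[k] = sorted[i++];                            // then this node
        i = eytzinger_fill(sorted, b, i, 2 * k + 1);  // then the right subtree
    }
    return i;
}

std::vector<int32_t> make_eytzinger(std::vector<int32_t> const& sorted) {
    std::vector<int32_t> b(sorted.size() + 1);  // b[0] is unused
    eytzinger_fill(sorted, b, 0, 1);
    return b;
}

// Lockstep batched lower_bound (assumes a perfect tree, see assert).
// out[j] = Eytzinger index of the first key >= q[j], or 0 if all keys are smaller.
void batch_lower_bound(std::vector<int32_t> const& b, int32_t const* q,
                       std::size_t m, std::size_t* out) {
    std::size_t const n = b.size() - 1;
    assert((n & (n + 1)) == 0);  // n must be 2^h - 1
    std::size_t h = 0;
    while ((std::size_t{1} << h) <= n) ++h;  // number of levels
    for (std::size_t j = 0; j < m; ++j) out[j] = 1;  // all queries start at the root
    for (std::size_t level = 0; level < h; ++level)
        for (std::size_t j = 0; j < m; ++j)
            out[j] = 2 * out[j] + (b[out[j]] < q[j]);  // go right iff key < query
    // Strip the trailing right turns plus the last left turn; what remains is the
    // last node where the search went left, i.e. the lower bound (0 if none).
    for (std::size_t j = 0; j < m; ++j)
        out[j] >>= __builtin_ctzll(~static_cast<unsigned long long>(out[j])) + 1;
}
```

For example, with keys 0, 10, ..., 140 and the sorted batch {-1, 0, 5, 10, 95, 140, 141}, `b[out[j]]` gives 0, 0, 10, 10, 100, 140, and the last query yields index 0 (no key >= 141). Because the queries are sorted, neighbouring queries tend to walk the same upper nodes, which then stay cached; the interleaving is mainly meant to overlap the misses on the lower levels.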

Important: this is about integer trees, which is why a van Emde Boas tree is used (maybe x-fast/y-fast tries later on).

I am curious to hear your two cents on this issue, given an interest in the highest achievable performance on large-scale trees. Thank you in advance for your time :-)


Solution

  • I've used SSE2/AVX2 to help perform a B+tree search. Here's code to perform a "binary search" on a full cache line of 16 DWORDs in AVX2:

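    #include <immintrin.h>  // AVX2 + BMI1 intrinsics (GCC/Clang: -mavx2 -mbmi)
    #include <cstdint>
    
    // perf-critical: ensure this is 64-byte aligned (a full cache line).
    union alignas(64) bnode
    {
        int32_t i32[16];
        __m256i m256[2];
    };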
    
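    // returns from 0 (if value < i32[0]) to 16 (if value >= i32[15]),
    // assuming the keys are sorted and `value` holds the search key
    // broadcast to all 8 lanes (e.g. with _mm256_set1_epi32).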
    unsigned bsearch_avx2(bnode const* const node, __m256i const value)
    {
        __m256i const perm_mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
    
        // compare the two halves of the cache line.
    
        __m256i cmp1 = _mm256_load_si256(&node->m256[0]);
        __m256i cmp2 = _mm256_load_si256(&node->m256[1]);
    
        cmp1 = _mm256_cmpgt_epi32(cmp1, value); // PCMPGTD
        cmp2 = _mm256_cmpgt_epi32(cmp2, value); // PCMPGTD
    
        // merge the comparisons back together.
        //
        // a permute is required to get the pack results back into order
        // because the 256-bit PACKSSDW packs each 128-bit lane independently,
        // leaving the results in the order 0-3, 8-11, 4-7, 12-15.
        //
        // alternatively, you could pre-process your data to remove the need
        // for the permute.
    
        __m256i cmp = _mm256_packs_epi32(cmp1, cmp2); // PACKSSDW
        cmp = _mm256_permutevar8x32_epi32(cmp, perm_mask); // PERMD
    
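        // finally create a move mask and count trailing
        // zeroes to get an index to the next node.
        // each 16-bit result contributes 2 mask bits, hence the / 2;
        // TZCNT returns 32 for a zero mask, so "no key > value" yields 16.
    
        unsigned mask = _mm256_movemask_epi8(cmp); // PMOVMSKB
        return _tzcnt_u32(mask) / 2; // TZCNT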
    }
    

    You'll end up with a single highly predictable branch per bnode, to test if the end of the tree has been reached.

    This should be trivially scalable to AVX-512.
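The descent loop that goes with the single-branch remark can be sketched as follows. Several assumptions here are hypothetical and not from the answer. It uses SSE2 instead of AVX2, so it builds on any x86-64 compiler without extra flags; with 128-bit packs no permute is needed. It uses a static 17-ary tree with implicit children (child i of node k is node 17k + i + 1) instead of a pointer-based B+tree. It uses GCC/Clang's `__builtin_ctz` instead of TZCNT. The names `snode`, `stree`, `child` and `node_search_sse2` are made up for this sketch.

```cpp
#include <emmintrin.h>  // SSE2 intrinsics: baseline on x86-64
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical node: 16 sorted keys in one 64-byte cache line.
struct alignas(64) snode {
    int32_t key[16];
};

// SSE2 counterpart of bsearch_avx2: how many keys are <= the query (0..16).
// `v` holds the query in all four lanes; the caller broadcasts it once.
static unsigned node_search_sse2(snode const& n, __m128i v) {
    __m128i const* p = reinterpret_cast<__m128i const*>(n.key);
    // PCMPGTD: all-ones lanes where key > query (loadu tolerates any alignment).
    __m128i c0 = _mm_cmpgt_epi32(_mm_loadu_si128(p + 0), v);
    __m128i c1 = _mm_cmpgt_epi32(_mm_loadu_si128(p + 1), v);
    __m128i c2 = _mm_cmpgt_epi32(_mm_loadu_si128(p + 2), v);
    __m128i c3 = _mm_cmpgt_epi32(_mm_loadu_si128(p + 3), v);
    // 128-bit packs keep element order: 32 -> 16 -> 8 bits, one byte per key.
    __m128i b = _mm_packs_epi16(_mm_packs_epi32(c0, c1), _mm_packs_epi32(c2, c3));
    unsigned mask = static_cast<unsigned>(_mm_movemask_epi8(b));  // bit i: key[i] > query
    // Bit 16 is a sentinel, so "no key > query" yields 16.
    return static_cast<unsigned>(__builtin_ctz(mask | 0x10000u));
}

// Implicit layout (assumption): child i of node k is node 17*k + i + 1.
static std::size_t child(std::size_t k, unsigned i) { return 17 * k + i + 1; }

class stree {
public:
    explicit stree(std::vector<int32_t> const& sorted)
        : nodes_((sorted.size() + 15) / 16) {
        std::size_t t = 0;
        fill(sorted, 0, t);
    }

    // Smallest key > x, or INT32_MAX if there is none (assumes x < INT32_MAX).
    int32_t upper_bound(int32_t x) const {
        __m128i const v = _mm_set1_epi32(x);  // built once per query, not per node
        int32_t res = INT32_MAX;
        std::size_t k = 0;
        while (k < nodes_.size()) {  // the one tree-shaped branch: left the tree yet?
            unsigned i = node_search_sse2(nodes_[k], v);
            // key[i] is the best candidate so far; keys under child i are <= it.
            // (i & 15 keeps the read in bounds; this select may compile to a CMOV.)
            int32_t cand = nodes_[k].key[i & 15];
            res = (i < 16) ? cand : res;
            k = child(k, i);
        }
        return res;
    }

private:
    // In-order fill: subtree of child i, then key[i], for i = 0..15, then child 16.
    // Slots past the end of `sorted` are padded with INT32_MAX.
    void fill(std::vector<int32_t> const& sorted, std::size_t k, std::size_t& t) {
        if (k >= nodes_.size()) return;
        for (unsigned i = 0; i < 16; ++i) {
            fill(sorted, child(k, i), t);
            nodes_[k].key[i] = t < sorted.size() ? sorted[t++] : INT32_MAX;
        }
        fill(sorted, child(k, 16), t);
    }

    std::vector<snode> nodes_;
};
```

With the query broadcast hoisted out of the loop, the per-node work in this sketch is four compares, three packs, one PMOVMSKB and one bit scan, plus the loop branch.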

    To get rid of the cross-lane PERMD in the search, each node can be preprocessed once with the following:

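    // Run once per node when building the tree. Swaps i32[4..7] with
    // i32[8..11] so that the lane-wise PACKSSDW emits the results in order;
    // after this, drop the _mm256_permutevar8x32_epi32 line from bsearch_avx2.
    void preprocess_avx2(bnode* const node)
    {
        // dword k of the middle 256 bits takes old dword {4,5,6,7,0,1,2,3}[k]
        __m256i const perm_mask = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
        // i32[4..11] straddles the two aligned halves: use unaligned load/store.
        __m256i *const middle = (__m256i*)&node->i32[4];
    
        __m256i x = _mm256_loadu_si256(middle);
        x = _mm256_permutevar8x32_epi32(x, perm_mask);
        _mm256_storeu_si256(middle, x);
    }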
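Why the preprocessing removes the permute can be checked without AVX2 hardware. The sketch below is a plain C++ scalar model (no intrinsics) of which key's comparison result lands in each 16-bit slot of the 256-bit PACKSSDW output, given that each 128-bit lane packs independently. `packs_order`, `preprocess_order` and `order16` are hypothetical names for this illustration only.

```cpp
#include <array>
#include <utility>

using order16 = std::array<int, 16>;

// Models _mm256_packs_epi32(cmp1, cmp2), where cmp1 holds stored keys 0..7 and
// cmp2 holds stored keys 8..15: returns which stored key fills each 16-bit slot.
order16 packs_order(order16 const& stored) {
    order16 out{};
    for (int w = 0; w < 16; ++w) {
        int lane = w / 8;            // 128-bit lane of the result
        int pos = w % 8;             // 16-bit slot within that lane
        int base = pos < 4 ? 0 : 8;  // first 4 slots come from cmp1, last 4 from cmp2
        out[w] = stored[base + 4 * lane + pos % 4];
    }
    return out;
}

// Models preprocess_avx2: swap keys 4..7 with keys 8..11.
order16 preprocess_order(order16 logical) {
    for (int j = 0; j < 4; ++j) std::swap(logical[4 + j], logical[8 + j]);
    return logical;
}
```

Fed the identity order, `packs_order` yields 0..3, 8..11, 4..7, 12..15, which is exactly what the dword permute {0,1,4,5,2,3,6,7} in bsearch_avx2 undoes; fed the preprocessed order, it yields 0..15 directly.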