Remove bounds checking in Rust loop in attempt to reach optimal compiler output


In an attempt to determine whether I can/should use Rust instead of the default C/C++ I'm looking into various edge cases, mostly with this question in mind: In the 0.1% of cases where it does matter, can I always get compiler output as good as gcc's (with the appropriate optimization flags)? The answer is most likely no, but let's see...

On Reddit there is a rather idiosyncratic example that studies the compiler output of a subroutine for a branchless sort algorithm.

Here is the benchmark C code:

#include <stdint.h>
#include <stdlib.h>
int32_t* foo(int32_t* elements, int32_t* buffer, int32_t pivot)
{
    size_t buffer_index = 0;

    for (size_t i = 0; i < 64; ++i) {
        buffer[buffer_index] = (int32_t)i;
        buffer_index += (size_t)(elements[i] < pivot);
    }
    /* Assumed intent: return a pointer one past the last kept index. */
    return buffer + buffer_index;
}

And here is the godbolt link with compiler output.

The first attempt with Rust looks like this:

pub fn foo0(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        buffer[buffer_index] = i as i32;
        buffer_index += (elements[i] < pivot) as usize; 
    }
}

There's quite a bit of bounds checking going on, see godbolt.

The next attempt eliminates the bounds check on elements:

pub unsafe fn foo1(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        unsafe {
            buffer[buffer_index] = i as i32;
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

That's a little better (see the same godbolt link as above).

Finally, let's try to remove the bounds checks altogether:

use std::ptr;

pub unsafe fn foo2(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    unsafe {
        for i in 0..buffer.len() {
            ptr::replace(&mut buffer[buffer_index], i as i32);
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

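This produces the same output as foo1, so the bounds check is still there; presumably it comes from indexing buffer[buffer_index] before ptr::replace is even called, not from ptr::replace itself. I'm certainly out of my depth here with those unsafe operations. That leads to my two questions: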

  • How can the bounds check be eliminated?
  • Does it even make sense to analyze edge cases like this? Or would the Rust compiler see through all this if presented with the whole algorithm instead of only a small fraction of it?

Regarding the last point, I'm curious, in general, whether Rust can be butchered to the point where it is as "literal", i.e. close to the metal, as C is. Seasoned Rust programmers will probably cringe at this line of investigation, but here it is...


Solution

  • You can achieve this using old-school pointer arithmetic.

    const N: usize = 64;
    pub fn foo2(elements: &Vec<i32>, mut buffer: [i32; N], pivot: i32) -> () {
        assert!(elements.len() >= N);
        let elements = &elements[..N];
        let mut buff_ptr = buffer.as_mut_ptr();
        for (i, &elem) in elements.iter().enumerate() {
            unsafe {
                // SAFETY: buff_ptr advances at most once per iteration, so in
                // iteration i it is at most i < N elements past the start.
                *buff_ptr = i as i32;
                if elem < pivot {
                    buff_ptr = buff_ptr.add(1);
                }
            }
        }
    }
    

    This version compiles into:

    example::foo2:
            push    rax
            cmp     qword ptr [rdi + 16], 64
            jb      .LBB7_4
            mov     r9, qword ptr [rdi]
            lea     r8, [r9 + 256]
            xor     edi, edi
    
            // Loop goes here
    .LBB7_2:
            mov     ecx, dword ptr [r9 + 4*rdi]
            mov     dword ptr [rsi], edi
            lea     rax, [rsi + 4]
            cmp     ecx, edx
            cmovge  rax, rsi
            mov     ecx, dword ptr [r9 + 4*rdi + 4]
            lea     esi, [rdi + 1]
            mov     dword ptr [rax], esi
            lea     rsi, [rax + 4]
            cmp     ecx, edx
            cmovge  rsi, rax
            mov     eax, dword ptr [r9 + 4*rdi + 8]
            lea     ecx, [rdi + 2]
            mov     dword ptr [rsi], ecx
            lea     rcx, [rsi + 4]
            cmp     eax, edx
            cmovge  rcx, rsi
            mov     r10d, dword ptr [r9 + 4*rdi + 12]
            lea     esi, [rdi + 3]
            lea     rax, [r9 + 4*rdi + 16]
            add     rdi, 4
            mov     dword ptr [rcx], esi
            lea     rsi, [rcx + 4]
            cmp     r10d, edx
            cmovge  rsi, rcx
            // Conditional branch to the loop beginning
            cmp     rax, r8
            jne     .LBB7_2
            pop     rax
            ret
    .LBB7_4:
            call    std::panicking::begin_panic
            ud2
    

    As you can see, the loop is unrolled and the only branch left in it is the jump back to the start of the loop.
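
    The same loop can probably also be written without raw pointer arithmetic: get_unchecked_mut is the write-side counterpart of the get_unchecked call already used in the question. A minimal sketch, assuming the caller guarantees at least 64 elements; the name foo3 and the &mut buffer parameter are choices made for this illustration, not part of the question:

    ```rust
    /// Hypothetical `foo3`: both accesses use the unchecked slice methods.
    ///
    /// # Safety
    /// `elements` must contain at least 64 items.
    pub unsafe fn foo3(elements: &[i32], buffer: &mut [i32; 64], pivot: i32) {
        let mut buffer_index: usize = 0;
        for i in 0..buffer.len() {
            // SAFETY: `buffer_index <= i < 64`, so the write is in bounds;
            // the caller guarantees `elements.len() >= 64`.
            unsafe {
                *buffer.get_unchecked_mut(buffer_index) = i as i32;
                buffer_index += (*elements.get_unchecked(i) < pivot) as usize;
            }
        }
    }
    ```

    Neither access is bounds-checked in a release build, but whether the resulting assembly matches the listing above has not been verified here.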

    However, I am surprised that the loop in foo2 is not eliminated: buffer is taken by value, so the writes have no observable effect and, apart from the assertion, the function should compile to a no-op. That would probably happen after inlining.
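
    One way to give the function an observable result is to write into a caller-owned buffer and return how many indices were kept. A sketch under those assumptions (the name partition_indices and the returned count are hypothetical, not from the question):

    ```rust
    const N: usize = 64;

    /// Hypothetical variant: writes candidate indices into `buffer` and returns
    /// how many of the first `N` elements were smaller than `pivot`.
    pub fn partition_indices(elements: &[i32], buffer: &mut [i32; N], pivot: i32) -> usize {
        assert!(elements.len() >= N);
        let elements = &elements[..N];
        let base = buffer.as_mut_ptr();
        let mut count = 0usize;
        for (i, &elem) in elements.iter().enumerate() {
            // SAFETY: `count <= i < N`, so `base.add(count)` stays inside `buffer`.
            unsafe {
                *base.add(count) = i as i32;
            }
            count += (elem < pivot) as usize;
        }
        count
    }
    ```

    After the call, the first `count` slots of the buffer hold the indices of the elements below the pivot; the slots after that are scratch space.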

    Also, changing the buffer parameter to a &mut reference does not change the generated code:

    example::foo2:
            push    rax
            cmp     qword ptr [rdi + 16], 64
            jb      .LBB7_4
            mov     r9, qword ptr [rdi]
            lea     r8, [r9 + 256]
            xor     edi, edi
    .LBB7_2:
            mov     ecx, dword ptr [r9 + 4*rdi]
            mov     dword ptr [rsi], edi
            lea     rax, [rsi + 4]
            cmp     ecx, edx
            cmovge  rax, rsi
            mov     ecx, dword ptr [r9 + 4*rdi + 4]
            lea     esi, [rdi + 1]
            mov     dword ptr [rax], esi
            lea     rsi, [rax + 4]
            cmp     ecx, edx
            cmovge  rsi, rax
            mov     eax, dword ptr [r9 + 4*rdi + 8]
            lea     ecx, [rdi + 2]
            mov     dword ptr [rsi], ecx
            lea     rcx, [rsi + 4]
            cmp     eax, edx
            cmovge  rcx, rsi
            mov     r10d, dword ptr [r9 + 4*rdi + 12]
            lea     esi, [rdi + 3]
            lea     rax, [r9 + 4*rdi + 16]
            add     rdi, 4
            mov     dword ptr [rcx], esi
            lea     rsi, [rcx + 4]
            cmp     r10d, edx
            cmovge  rsi, rcx
            cmp     rax, r8
            jne     .LBB7_2
            pop     rax
            ret
    .LBB7_4:
            call    std::panicking::begin_panic
            ud2
    

    So rustc probably already passes the by-value buffer parameter as a pointer in the LLVM IR, unfortunately.
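
    If the buffer ends up being passed as a pointer anyway, the C signature can also be mirrored literally with raw pointers. A minimal sketch, assuming the intended return value is the advanced buffer pointer (the C code declares an int32_t* return type); foo_raw is a hypothetical name:

    ```rust
    /// Hypothetical literal translation of the C signature.
    ///
    /// # Safety
    /// `elements` must be valid for reading 64 `i32`s and `buffer` must be
    /// valid for writing 64 `i32`s.
    pub unsafe fn foo_raw(elements: *const i32, buffer: *mut i32, pivot: i32) -> *mut i32 {
        let mut out = buffer;
        for i in 0..64usize {
            // SAFETY: `out` advances at most once per iteration, so it is at
            // most `buffer + i` here; the caller guarantees both regions
            // hold 64 items.
            unsafe {
                out.write(i as i32);
                out = out.add((elements.add(i).read() < pivot) as usize);
            }
        }
        out
    }
    ```

    All safety obligations move to the caller here, which is the trade-off that the C version makes as well.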