Search code examples
Tags: rust, vectorization, avx

How can I optimize search in small fixed size array?


I'd like to find the first occurrence of a byte in a 16-byte array. If I write a naive version (either using iterators or with a manual loop), rustc doesn't appear to vectorize (https://godbolt.org/z/fbKfvxTdv).

My vague idea of an optimized solution would be something like

  • broadcast x into some vector register
  • compare it with every element of y at once
  • collect the 16 resulting bools into a u16
  • use trailing_zeros on that u16 to find which one matched (bit i corresponds to element i, so the lowest set bit marks the first match).

Instead, the manual loop is simply fully unrolled into 16 comparisons and jumps, and the iterator-based version doesn't seem to take advantage of the known fixed length (exactly 16) at all.

/// Returns the index of the first element of `y` equal to `x`,
/// or `None` when no element matches (iterator-based version).
pub fn find_first(x: u8, y: &[u8;16]) -> Option<usize> {
    y.iter()
        .enumerate()
        .find(|&(_, &byte)| byte == x)
        .map(|(idx, _)| idx)
}

/// Scans `y` with an explicit index loop and returns the position of the
/// first byte equal to `x`, or `None` when no byte matches.
pub fn find_first_manual(x: u8, y: &[u8;16]) -> Option<usize> {
    let mut idx = 0;
    while idx < 16 {
        if x == y[idx] {
            return Some(idx);
        }
        idx += 1;
    }
    None
}

Solution

  • Nightly Rust provides the portable SIMD module (std::simd), which is very convenient for this kind of problem.

    Please find below an attempt to realise the algorithm you suggested, as well as the disassembled version from the compiler explorer, and a benchmark showing that, at least on my computer, the SIMD version you suggested is beneficial.

    #![feature(test)] // cargo bench
    #![feature(portable_simd)]
    
    /// Index of the first byte of `y` equal to `x`, or `None`
    /// (iterator-based baseline).
    pub fn find_first(
        x: u8,
        y: &[u8; 16],
    ) -> Option<usize> {
        y.iter().copied().position(|byte| byte == x)
    }
    
    /// Index of the first byte of `y` equal to `x`, or `None`
    /// (explicit-loop baseline with an early return on the first hit).
    pub fn find_first_manual(
        x: u8,
        y: &[u8; 16],
    ) -> Option<usize> {
        for (pos, &byte) in y.iter().enumerate() {
            if byte == x {
                return Some(pos);
            }
        }
        None
    }
    
    /// Index of the first lane of `y` equal to `x`, or `None`, computed with a
    /// single 16-lane SIMD equality comparison (requires nightly `portable_simd`).
    pub fn find_first_simd(
        x: u8,
        y: &[u8; 16],
    ) -> Option<usize> {
        use std::simd::{cmp::SimdPartialEq, u8x16};
        // Load the haystack into one vector, compare it lane-wise against the
        // needle broadcast to all lanes, and report the lowest lane that matched.
        u8x16::from_array(*y)
            .simd_eq(u8x16::splat(x))
            .first_set()
    }
    /*
    Compiled on https://godbolt.org/ with rustc nightly and the options
      -C opt-level=3 -C target-feature=+avx -C target-feature=+avx2
    this produces:
    
    example::find_first_simd::hd89fd3b592dbfe03:
            vmovd   xmm0, edi
            vpbroadcastb    xmm0, xmm0
            vpcmpeqb        xmm0, xmm0, xmmword ptr [rsi]
            vpmovmskb       ecx, xmm0
            xor     eax, eax
            test    ecx, ecx
            setne   al
            rep       bsf edx, ecx
            ret
    */
    
    fn main() {
        // Small demo: compare the three search variants on one fixed haystack,
        // including a needle (9) that does not occur in it.
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1];
        for needle in [0, 4, 9] {
            println!("first({}) {:?}", needle, find_first(needle, &haystack));
            println!("first_manual({}) {:?}", needle, find_first_manual(needle, &haystack));
            println!("first_simd({}) {:?}", needle, find_first_simd(needle, &haystack));
        }
    }
    /*
    first(0) Some(0)
    first_manual(0) Some(0)
    first_simd(0) Some(0)
    first(4) Some(4)
    first_manual(4) Some(4)
    first_simd(4) Some(4)
    first(9) None
    first_manual(9) None
    first_simd(9) None
    */
    
    #[cfg(test)]
    mod tests {
        use super::*;
        // `test` ships with the nightly toolchain; it is not a Cargo.toml dependency.
        extern crate test;

        /// Number of fixture passes performed inside each `Bencher::iter` sample.
        const REPEAT: usize = 1000;

        /// Shared benchmark body: runs `search` over the fixed haystack for the
        /// needles 0, 4 and 9, `REPEAT` times per sample. Inputs and results go
        /// through `black_box` so the optimizer cannot constant-fold the searches.
        fn bench_search<F>(b: &mut test::Bencher, search: F)
        where
            F: Fn(u8, &[u8; 16]) -> Option<usize>,
        {
            b.iter(|| {
                for _ in 0..REPEAT {
                    let y = [0, 1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1];
                    for x in [0, 4, 9] {
                        test::black_box(search(
                            test::black_box(x),
                            test::black_box(&y),
                        ));
                    }
                }
            });
        }

        #[bench]
        fn bench_find_first_simd(b: &mut test::Bencher) {
            bench_search(b, find_first_simd);
        }

        #[bench]
        fn bench_find_first(b: &mut test::Bencher) {
            bench_search(b, find_first);
        }

        #[bench]
        fn bench_find_first_manual(b: &mut test::Bencher) {
            bench_search(b, find_first_manual);
        }
    }
    /*
    test tests::bench_find_first        ... bench:       4,282 ns/iter (+/- 74)
    test tests::bench_find_first_manual ... bench:       4,246 ns/iter (+/- 45)
    test tests::bench_find_first_simd   ... bench:       2,716 ns/iter (+/- 45)
    */