Search code examples
Tags: assembly, apple-m1, simd, arm64, neon

Fastest way to search an array on an M1 Mac


I am trying to load an array of `uint16_t` values from memory and find the index of the first element that is less than some threshold, as fast as possible on an M1 Mac. I have been looking through the NEON instructions, but I couldn't find a good way to do it. There are vector comparison instructions, but they produce vectors whose elements are either all ones or all zeros. How would I get that result into a general-purpose register? And is there a way to break out of a loop based on the result of a vector instruction?


Solution

  • // int32_t searchArrayU16(uint16_t *pArray, uint16_t threshold, uint32_t len);
    //
    // Returns the index of the first element of pArray[0 .. len) that is
    // strictly less than `threshold` (unsigned 16-bit compare), or -1 if
    // there is none.
    //
    // ABI:    AAPCS64, leaf function, no stack usage.  Exported both as
    //         `searchArrayU16` (ELF) and `_searchArrayU16` (Mach-O/macOS,
    //         where the C name of a function carries a leading underscore).
    // In:     x0 = pArray
    //         w1 = threshold (only the low 16 bits are used)
    //         w2 = len       (uint32_t; only the low 32 bits of x2 are used)
    // Out:    w0 = index of the first match, or -1
    // Clobb:  x1, x3-x7, x9, x15, v0, v1, v16-v31, NZCV
    //
    // Any len is accepted, including 0.  Whole 128-element blocks are
    // scanned with NEON; the remaining len % 128 elements are scanned with
    // a scalar loop, so no element at index >= len is ever read.
    // NOTE: a match at an index >= 2^31 does not fit the int32_t result.

        .arch   armv8-a
        .text


    pArray0 .req    x0      // cursor for the 1st and 3rd 64-byte quarter of a block
    thresh  .req    w1
    len     .req    w2      // uint32_t: the upper half of x2 is unspecified by AAPCS64
    stride  .req    x3
    pArray1 .req    x4      // cursor for the 2nd and 4th quarter (pArray0 + 64)
    count   .req    w5      // elements consumed so far
    val0    .req    x6
    val0w   .req    w6
    val1    .req    x7
    val1w   .req    w7
    vlen    .req    w9      // len rounded down to a multiple of 128

        .global searchArrayU16
        .global _searchArrayU16

    .balign 64
    searchArrayU16:
    _searchArrayU16:
        dup     v0.8h, thresh               // v0 = threshold in all eight u16 lanes
        mov     x15, #0x0201                // x15 = bytes 01 02 04 08 10 20 40 80
        movk    x15, #0x0804, lsl #16       //   (little-endian): byte lane i
        movk    x15, #0x2010, lsl #32       //   carries the weight 1 << i
        movk    x15, #0x8040, lsl #48
        dup     v1.2d, x15                  // v1 = those weights in both halves
        mov     stride, #128                // each cursor hops over the other's quarter
        add     pArray1, pArray0, #64
        mov     count, #0
        and     vlen, len, #0xffffff80
        cbz     vlen, 3f                    // fewer than 128 elements: scalar path only

    // Vector loop: one 128-element (256-byte) block per iteration.
    .balign 64
    1:
        ld1     {v16.8h-v19.8h}, [pArray0], stride  // elements   0..31
        ld1     {v20.8h-v23.8h}, [pArray1], stride  // elements  32..63
        ld1     {v24.8h-v27.8h}, [pArray0], stride  // elements  64..95
        ld1     {v28.8h-v31.8h}, [pArray1], stride  // elements  96..127

        // lane = 0xffff where threshold > element, i.e. element < threshold
        cmhi    v16.8h, v0.8h, v16.8h
        cmhi    v17.8h, v0.8h, v17.8h
        cmhi    v18.8h, v0.8h, v18.8h
        cmhi    v19.8h, v0.8h, v19.8h
        cmhi    v20.8h, v0.8h, v20.8h
        cmhi    v21.8h, v0.8h, v21.8h
        cmhi    v22.8h, v0.8h, v22.8h
        cmhi    v23.8h, v0.8h, v23.8h
        cmhi    v24.8h, v0.8h, v24.8h
        cmhi    v25.8h, v0.8h, v25.8h
        cmhi    v26.8h, v0.8h, v26.8h
        cmhi    v27.8h, v0.8h, v27.8h
        cmhi    v28.8h, v0.8h, v28.8h
        cmhi    v29.8h, v0.8h, v29.8h
        cmhi    v30.8h, v0.8h, v30.8h
        cmhi    v31.8h, v0.8h, v31.8h

        // keep the low byte of each u16 mask lane: 16 elements per register
        uzp1    v16.16b, v16.16b, v17.16b
        uzp1    v18.16b, v18.16b, v19.16b
        uzp1    v20.16b, v20.16b, v21.16b
        uzp1    v22.16b, v22.16b, v23.16b
        uzp1    v24.16b, v24.16b, v25.16b
        uzp1    v26.16b, v26.16b, v27.16b
        uzp1    v28.16b, v28.16b, v29.16b
        uzp1    v30.16b, v30.16b, v31.16b

        // byte for element e keeps only bit (e % 8)
        and     v16.16b, v16.16b, v1.16b
        and     v18.16b, v18.16b, v1.16b
        and     v20.16b, v20.16b, v1.16b
        and     v22.16b, v22.16b, v1.16b
        and     v24.16b, v24.16b, v1.16b
        and     v26.16b, v26.16b, v1.16b
        and     v28.16b, v28.16b, v1.16b
        and     v30.16b, v30.16b, v1.16b

        // three pairwise-add rounds: byte k of v16 ends up holding the
        // 8 match bits of elements 8k .. 8k+7
        addp    v16.16b, v16.16b, v18.16b
        addp    v20.16b, v20.16b, v22.16b
        addp    v24.16b, v24.16b, v26.16b
        addp    v28.16b, v28.16b, v30.16b

        addp    v16.16b, v16.16b, v20.16b
        addp    v24.16b, v24.16b, v28.16b

        add     count, count, #128
        addp    v16.16b, v16.16b, v24.16b

        // The GPR moves below must wait for the whole reduction above.
        mov     val0, v16.d[0]              // match bits for elements 0..63 of this block
        mov     val1, v16.d[1]              // match bits for elements 64..127

        orr     x15, val0, val1
        cbnz    x15, 2f                     // some element in this block matched

        cmp     count, vlen
        b.lo    1b

    // Scalar tail: elements [count, len); pArray0 points at pArray[count].
    3:
        cmp     count, len
        b.hs    5f
        ldrh    val0w, [pArray0], #2
        cmp     val0w, thresh, uxth         // uxth: ignore the unspecified upper bits of w1
        b.lo    4f
        add     count, count, #1
        b       3b
    4:
        mov     w0, count                   // match found by the scalar tail
        ret
    5:
        mov     w0, #-1                     // no match found
        ret

    // Match in the block that ended at `count`: bit i of val0 (val1) is set
    // iff element count-128+i (count-64+i) is below the threshold.
    .balign 16
    2:
        rbit    val0, val0                  // rbit + clz = count trailing zeros
        rbit    val1, val1
        cmp     val0, #0                    // is the first match in the lower 64?
        sub     w0, count, #128
        sub     w1, count, #64
        clz     val0, val0
        clz     val1, val1
        add     w0, w0, val0w
        add     w1, w1, val1w
        csel    w0, w0, w1, ne
        ret
    .end
    

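    Here you are. It returns -1 when no match is found.
    It should work on all armv8-a cores or above.
    How it works: `cmhi` against the broadcast threshold sets a lane to all ones for every element below the threshold. `uzp1` then narrows each 16-bit lane to one byte. An `and` with the weights 1, 2, 4, …, 128 leaves one distinct bit per byte. Three rounds of `addp` fold the 128 results into a 128-bit bitmask, where bit i is set exactly when element i matched.
    `mov xN, v16.d[0]` / `v16.d[1]` (an alias of `umov`) moves the two 64-bit halves into general-purpose registers. There, `orr` + `cbnz` leaves the loop on a match, and `rbit` + `clz` (count trailing zeros) gives the index of the first match.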