Tags: assembly, optimization, micro-optimization, loop-unrolling, y86

The most efficient way of counting positive, negative and zero numbers using loop unrolling


Say I have the following code, which simply checks whether each number is positive (as opposed to negative or zero) and, if it is, adds 1 to a counter; negative and zero values are ignored. I can do this with simple loop unrolling:

Loop:
    mrmovq (%rdi), %r10     # read val[0] from src
    andq %r10, %r10         # val[0] <= 0?
    jle Npos1               # if so, goto Npos1
    iaddq $1, %rax          # count++

Npos1:
    mrmovq 8(%rdi), %r11    # read val[1] from src+8
    andq %r11, %r11         # val[1] <= 0?
    jle Npos2               # if so, goto Npos2
    iaddq $1, %rax          # count++

Npos2:
    mrmovq 16(%rdi), %r11   # read val[2] from src+16
    andq %r11, %r11         # val[2] <= 0?
    jle Npos3               # if so, goto Npos3
    iaddq $1, %rax          # count++

My question is how I can get the same efficient structure if I also want to check for zero and negative values. In that case I'll have three counters (one for positive, one for negative and one for zero). One inefficient way would be the following, where I try to keep the same structure as the previous example (positive counts are stored in %rax, negatives in %rbx and zeros in %rcx):

Loop:   mrmovq (%rdi), %r10 # read val from src...
        andq %r10, %r10     # val <= 0?
        jle Npos            # if so, goto Npos:
        irmovq $1, %r11
        addq %r11, %rax     # Count positives in rax - count_pos++ 
        jmp Rest 
Npos:   andq %r10, %r10     # Not positive 
        je Zero
        irmovq $1, %r11
        addq %r11, %rbx     # Count negatives in rbx - count_neg++
        jmp Rest
Zero:   irmovq $1, %r11
        addq %r11, %rcx     # Count zeroes in rcx - count_zero++
Rest:   irmovq $1, %r10
        subq %r10, %rdx     # len--
        irmovq $8, %r10
        addq %r10, %rdi     # src++
        addq %r10, %rsi     # dst++
        andq %rdx,%rdx      # len > 0?
        jg Loop             # if so, goto Loop:

Solution

  • update: see the very end for a non-branching version that should be much better, and trivial to unroll. But the rest of the answer is still worth reading, IMO.

    I did find a way to save a couple of instructions executed per value tested with an unroll, but the gain is pretty minor compared to what I managed with a well-optimized version that uses loop tail duplication (see below).


    y86 is too stripped-down to allow efficient code compared to real architectures in a lot of cases. For one thing, there doesn't seem to be a way to conditionally increment without clobbering flags. (x86 has lea rax, [rax+1]).

    I don't see a way to save a lot of instructions by only counting positive and zero, and calculating the negative count from that after the loop. You still have to branch to test each value. update: no you don't, because you can emulate x86's setcc using y86's cmov!


    However, I did find several big improvements in your code:

    • reuse the flags set by the first test, instead of re-testing

    • Another major thing is to hoist the %r11 = 1 out of the loop, so you can just increment with one insn. Setting up constants in registers is a really common thing even in real code. Most ISAs (including RISC load-store machines) have add-immediate instructions, like x86's add $1, %rax, but y86 doesn't so it needs this technique even for increments (x86 inc %rax)!

    • sub sets flags, so use them instead of doing a separate test.

    Style issues:

    With descriptive label names, you don't need as many comments.

    Also, indent your operands to a consistent column, rather than just a single space after the variable-length mnemonic. It's more readable. I like indenting branch targets less, to make them stand out, but there are so many branches in this code that it actually just looks ugly :/

            irmovq  $1, %r11     # hoisted out of the loop
            irmovq  $8, %r8
    
    Loop:   mrmovq  (%rdi), %r10 # read val from src...
            andq    %r10, %r10   # set flags from val
            jle    not_positive
            addq    %r11, %rax   # Count positives in rax - count_pos++ 
            jmp    Rest 
    not_positive:
            je     Zero          # flags still from val
            addq    %r11, %rbx   # Count negatives in rbx - count_neg++
            jmp    Rest
    Zero:
            addq    %r11, %rcx   # Count zeroes in rcx - count_zero++
    Rest:
            addq    %r8, %rdi    # src+=8
            # addq  %r8, %rsi    # dst++ left out: dst is never used in this loop.
            # Also note that "si" stands for source index, so for human readability
            # prefer keeping src pointers in %rsi and dest pointers in %rdi.
            subq    %r11, %rdx   # len--, setting flags
            jg     Loop          # } while( len-- > 1).  fall through when rdx=0
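
For checking any of the assembly variants against a known-good result, a plain C model of the same loop can help. This is only a sketch: the function name and the out-parameters (standing in for %rax, %rbx and %rcx) are made up for illustration, and the early return for `len <= 0` is an added guard, since the assembly loop as written assumes at least one element.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Reference model of the loop above. count_signs_ref and its parameters are
 * illustrative names, not part of the original code. */
void count_signs_ref(const int64_t *src, int64_t len,
                     int64_t *pos, int64_t *neg, int64_t *zero)
{
    assert(pos != NULL && neg != NULL && zero != NULL);
    *pos = *neg = *zero = 0;
    if (len <= 0)
        return;                 /* added guard: the assembly assumes len >= 1 */
    do {
        int64_t val = *src++;   /* mrmovq (%rdi), %r10 ; addq %r8, %rdi */
        if (val > 0)
            ++*pos;             /* count_pos++  */
        else if (val == 0)      /* decided from the same test of val */
            ++*zero;            /* count_zero++ */
        else
            ++*neg;             /* count_neg++  */
    } while (--len > 0);        /* subq %r11, %rdx ; jg Loop */
}
```

The `do { } while (--len > 0)` shape mirrors the way the decrement itself sets the flags that the loop branch consumes.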
    

    loop tail duplication:

    You can increase code size but decrease the number of instructions that actually run by duplicating the loop tail.

    I also restructured the loop so the positive path has only one taken branch per iteration (the loop branch itself); the zero and negative paths take two.

            irmovq $1, %r11       # hoisted out of the loop
            irmovq $8, %r8
    
    Loop:   mrmovq (%rdi), %r10   # read val from src...
            addq   %r8, %rdi      # src+=8 for next iteration
    
            andq   %r10, %r10     # set flags from val
            je    Zero
            jl    Negative
            # else Positive
            addq   %r11, %rax     # Count positives in rax - count_pos++ 
    
            subq   %r11, %rdx
            jg    Loop
            jmp   end_loop
    Negative:
            addq   %r11, %rbx     # Count negatives in rbx - count_neg++
    
            subq   %r11, %rdx
            jg    Loop 
            jmp   end_loop
    Zero:
            addq   %r11, %rcx     # Count zeroes in rcx - count_zero++
    
            subq   %r11, %rdx     # len--, setting flags
            jg Loop               # } while( len-- > 1).  fall through when rdx=0
    end_loop:
    

    There's not a lot to gain from unrolling, since the loop body is so big. If you did unroll, you might do it like the version below.

    Note that we only update and check len once per iteration of the unrolled loop. This means we need a cleanup step for a leftover element, but decrementing and checking after every element would mostly defeat the purpose of unrolling.

    Unrolled by two, with tail duplication

            irmovq $1, %r11       # hoisted out of the loop
            irmovq $2, %r12
            irmovq $16, %r8
    
            subq   %r12, %rdx     # len -= 2 up front: rdx >= 0 means at least one more pair
            jl     end_loop       # unrolled loop requires len >= 2
    
    Loop:   mrmovq (%rdi), %r10   # read val from src...
            mrmovq 8(%rdi), %r9   # read next val here so we don't have to duplicate this
            addq   %r8, %rdi      # src+=16 for next iteration
    
            andq   %r10, %r10     # set flags from val
            je    Zero1
            jl    Negative1
            # else Positive1
            addq   %r11, %rax     # Count positives in rax - count_pos++ 
    
            andq   %r9, %r9       # set flags from val2
            je    Zero2
            jl    Negative2
    Positive2:
            addq   %r11, %rax     # Count positives in rax - count_pos++ 
    
            subq   %r12, %rdx     # loop tail
            jge   Loop
            jmp   end_loop
    
    Negative1:
            addq   %r11, %rbx     # Count negatives in rbx - count_neg++
    
            andq   %r9, %r9       # set flags from val2
            je    Zero2
            jg    Positive2
    Negative2:
            addq   %r11, %rbx     # Count negatives in rbx - count_neg++
    
            subq   %r12, %rdx     # loop tail
            jge   Loop 
            jmp   end_loop
    
    Zero1:
            addq   %r11, %rcx     # Count zeroes in rcx - count_zero++
    
            andq   %r9, %r9       # set flags from val2
            jg    Positive2
            jl    Negative2
    Zero2:
            addq   %r11, %rcx     # Count zeroes in rcx - count_zero++
    
            subq   %r12, %rdx     # len-=2, setting flags
            jge   Loop            # fall through when rdx=-1 or -2
    end_loop:
    
    # loop epilogue: rdx is -2 here if nothing is left, or -1 if one element is left (odd len)
            addq  %r12, %rdx      # undo the bias: rdx = 0 or 1
            je   all_done
            # else there's one more to do
            #... load and test a single element
    

    I wouldn't be surprised if there's an off-by-one error in my loop conditions or something.
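
One way to sanity-check loop conditions like these is to model just the counter arithmetic in C and run it on odd and even lengths. The sketch below assumes the length counter is biased by 2 before the loop and un-biased in the epilogue; the function names are hypothetical, and the per-element branching is folded into a helper because only the loop control is of interest here.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Illustrative helper: classify one value into one of three counters. */
static void classify(int64_t val, int64_t *pos, int64_t *neg, int64_t *zero)
{
    if (val > 0)
        ++*pos;
    else if (val == 0)
        ++*zero;
    else
        ++*neg;
}

/* Model of an unroll-by-two loop with a one-element cleanup step. */
void count_signs_unroll2(const int64_t *src, int64_t len,
                         int64_t *pos, int64_t *neg, int64_t *zero)
{
    assert(len >= 0);
    *pos = *neg = *zero = 0;

    len -= 2;                       /* bias the counter by the unroll factor */
    while (len >= 0) {              /* at least two elements remain */
        classify(src[0], pos, neg, zero);
        classify(src[1], pos, neg, zero);
        src += 2;                   /* 16 bytes, like addq %r8, %rdi */
        len -= 2;                   /* one update and one check per pair */
    }
    /* here len is -2 (nothing left) or -1 (one element left) */
    len += 2;                       /* undo the bias: 0 or 1 */
    if (len != 0)
        classify(src[0], pos, neg, zero);
}
```

Running this on lengths 0, 1, 4 and 5 exercises every exit path of the loop control.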


    Like Jester pointed out in comments, x86 can count negatives with

    sar   $63, %r10     # broadcast the sign bit to all bits: -1 or 0
    sub   %r10, %rbx    # subtracting -1 (or 0): i.e. add 1 (or 0)
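
The same sign-bit trick can be written in C as a rough illustration. This sketch uses an unsigned shift to extract the sign bit and then rebuilds the all-ones mask, because right-shifting a negative signed value is implementation-defined in C; the function name is made up for the example.

```c
#include <assert.h>
#include <stdint.h>

/* Count negative values without branching on each value.
 * count_negatives_sar is an illustrative name, not from the original post. */
int64_t count_negatives_sar(const int64_t *src, int64_t len)
{
    uint64_t count = 0;
    assert(len >= 0);
    for (int64_t i = 0; i < len; i++) {
        uint64_t sign = (uint64_t)src[i] >> 63; /* 1 if negative, else 0 */
        uint64_t mask = 0 - sign;               /* all-ones or 0, like the result of sar $63 */
        count -= mask;                          /* subtracting all-ones (-1) adds 1 */
    }
    return (int64_t)count;
}
```

Unsigned arithmetic wraps modulo 2^64, so subtracting the all-ones mask is well defined and increments the count by exactly one.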
    

    Update: non-branching version that uses cmov to emulate setcc

    We can use cmov to set a register to 0 or 1, and then add that to our count. This avoids all branching. (0 is the additive identity. This basic technique works for any operation that has an identity element. e.g. all-ones is the identity element for AND. 1 is the identity element for multiply.)

    Unrolling this is straightforward, but there are 3 instructions of loop overhead compared to 8 instructions that need to be repeated. The gains would be fairly small.

            irmovq $1, %r11       # hoisted out of the loop
            irmovq $8, %r8
            rrmovq %rdx, %rbx     # save initial len; neg_count is calculated after the loop
    
    Loop:   mrmovq (%rdi), %r10   # read val from src...
            addq   %r8, %rdi      # src+=8 for next iteration
    
            andq   %r10, %r10     # set flags from val
    
            irmovq $0, %r13
            cmovg  %r11, %r13     # emulate setcc
            irmovq $0, %r14
            cmove  %r11, %r14
    
            addq   %r13, %rax     # pos_count  += (val >  0)
            addq   %r14, %rcx     # zero_count += (val == 0)
    
            subq   %r11, %rdx     # len-=1, setting flags
            jg    Loop            # fall through when rdx=0
    end_loop:
    
            subq   %rax, %rbx
            subq   %rcx, %rbx     # neg_count = initial_len - pos_count - zero_count
    

    If branches (esp. unpredictable branches) are expensive, this version will do much better. Using Jester's suggestion of calculating one of the counts from the other two helped a lot in this case.

    There's pretty good instruction-level parallelism here. The two separate setcc -> add dependency chains can run in parallel once the test result is ready.
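
For comparison, the non-branching loop corresponds roughly to the following C, where each comparison yields 0 or 1 and is added unconditionally. This is a sketch with a hypothetical function name; a C compiler is free to choose its own instructions, so it does not guarantee branch-free machine code.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Model of the cmov-based loop: two independent 0-or-1 adds per value, with
 * the negative count derived afterwards. Names are illustrative only. */
void count_signs_branchless(const int64_t *src, int64_t len,
                            int64_t *pos, int64_t *neg, int64_t *zero)
{
    int64_t p = 0, z = 0;
    assert(len >= 0);
    for (int64_t i = 0; i < len; i++) {
        int64_t val = src[i];
        p += (val > 0);     /* like irmovq $0 ; cmovg ; addq */
        z += (val == 0);    /* like irmovq $0 ; cmove ; addq */
    }
    *pos = p;
    *zero = z;
    *neg = len - p - z;     /* neg_count = initial_len - pos_count - zero_count */
}
```

The two accumulations do not depend on each other, which is the same independence that lets the two dependency chains in the assembly run in parallel.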