Tags: assembly, x86, masm, masm32

Looking for improvement to a procedure to determine a sorted array's median


Function parameters are the sorted array and the length of the array. The goal is to determine the median of an odd or even length array.

Odd-length arrays are handled simply by taking the exact middle element; even-length arrays are handled by getting the two elements that "straddle" the midpoint and averaging them.

My question is about the code after the even_: label: I had to recompute the index from scratch to fetch the left and right straddle values, in the manner that you see.

At the line mov eax, [edi+eax-4], I can manipulate this with differing multiples of 4 and get any index position value I want. However, if I immediately follow the instruction mov eax, [edi+eax-4] with mov esi, [edi+eax +/- any multiple of 4], I always get "0" (esi arbitrarily selected).

So, is the way I did it the best way or am I lacking some wisdom on how to access two array elements in one go, so to speak?

GetMedian   PROC
    push ebp
    mov ebp, esp
    mov eax, [ebp+12]           ; eax = length of array.
    mov ebx, 2
    cdq
    div ebx                     ; eax = Length of array/2.
    cmp edx,0
    je even_                    ; Jump to average the straddle.
    mov ebx, TYPE DWORD
    mul ebx                     ; eax now contains our target index.
    mov edi, [ebp+8]
    mov eax, [edi+eax]          ; Access array[eax].
    jmp toTheEnd
even_:
    mov ebx, TYPE DWORD
    mul ebx                     ; eax now contains our target index.
    mov edi, [ebp+8]            ; edi now contains @array[0].
    mov eax, [edi+eax-4]        ; Dereferences array[left] so a value is in eax.
    mov esi, eax                ; save eax (value left of straddle).
    mov eax, [ebp+12]           ; eax = length of array.
    mov ebx, 2
    cdq
    div ebx
    mov ebx, TYPE DWORD
    mul ebx                     ; eax now contains our target index.
    mov edi, [ebp+8]
    mov eax, [edi+eax]          ; Access array[right] (value right of straddle).
    add eax, esi                ; list[eax-1] + list[eax].
    mov ebx, 2
    cdq
    div ebx
toTheEnd:
    pop ebp
    ret 12
GetMedian ENDP

Solution

  • BTW, your code doesn't actually work: mov ebx, 2 clobbers ebx, but you don't save/restore it (the same goes for esi and edi). So you've stepped on registers that are call-preserved in all the usual ABIs / calling conventions. See the tag wiki.

    Also, I think ret 12 should be ret 8, since you take two 4 byte args. (See below).


    Here's an interesting idea: branchless by always adding two elements. For an odd-length array, it's the same two elements. For an even-length array, it's the middle-round-down and middle-round-up.
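    As a rough C model of that idea (a sketch only: the function name is made up for illustration, and unsigned 32-bit elements are an assumption), the two indices differ by one for an even length and coincide for an odd length:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical C model of the branchless idea; not the original author's code.
 * Assumes unsigned 32-bit elements and a sum that fits in 32 bits. */
uint32_t median_branchless(const uint32_t *arr, size_t len)
{
    assert(len > 0);                      /* an empty array has no median */
    size_t hi = len / 2;                  /* middle-round-up (the middle itself if odd) */
    size_t lo = hi - (len % 2 == 0);      /* middle-round-down; same as hi if odd */
    return (arr[lo] + arr[hi]) / 2;       /* odd length: (x + x) / 2 == x */
}
```

    For {1, 2, 3, 4} this averages arr[1] and arr[2]; for {1, 3, 9} it adds arr[1] to itself and halves it again.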

    If your code is actually called with the same array length repeatedly, so a branch will predict well, a conditional branch will probably be better (on test ecx, 1 / jnz odd, or jc after the shift), especially if odd lengths are the common case. Otherwise, it's sometimes worth doing something unconditionally, even if it's not always needed.

    ; Untested
    GetMedian   PROC
        ;; return in eax.  clobbers: ecx, edx (which don't need to be saved/restored)
        mov   ecx, [esp+8]            ; ecx = unsigned len
        mov   edx, [esp+4]            ; edx = int *arr
        shr   ecx, 1                  ; ecx = len/2.  CF = the bit shifted out. 0 means even, 1 means odd
    
        mov   eax, [edx + ecx*4]      ; eax = arr[len/2]  (mov leaves CF untouched)
        adc   ecx, -1                 ; ecx += CF - 1: unchanged if odd, one lower if even
        add   eax, [edx + ecx*4]      ; eax += arr[len/2 - 1 + (len&1)]
    
        shr   eax, 1                  ; eax /= 2  (or sar for arithmetic shift)
        ret   8                       ; two 4-byte args; ret 12 in the question looks like a bug
    GetMedian ENDP
    ;; 5 instructions, plus loading args from the stack, and the ret.
    

    I left off the instructions to make a stack frame, because this is a leaf function with no need for any local storage. Using ebp doesn't make anything easier or help with backtraces, and is a waste of instructions.

    For most conditions, you have to use setcc to get a 0 or 1 in a register based on the flag. But CF is special: add-with-carry and subtract-with-borrow use it (which I take advantage of here), and so do the rotate-through-carry instructions. It's more common to see adc reg, 0, which adds CF. Here the second index has to be one lower only when CF is clear (even length), so I used adc reg, -1 to add -1 or 0 depending on CF.
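    To see what the carry-consuming instructions compute, here is a small C model (the helper names adc32 and sbb32 are made up for illustration). With a source operand of -1, adc adds CF - 1 and sbb adds 1 - CF, so either one can step an index by one depending only on the carry flag:

```c
#include <assert.h>
#include <stdint.h>

/* Model of `adc dst, src`: dst + src + CF, truncated to 32 bits. */
uint32_t adc32(uint32_t dst, uint32_t src, unsigned cf)
{
    assert(cf <= 1);            /* CF is a single bit */
    return dst + src + cf;
}

/* Model of `sbb dst, src`: dst - src - CF, truncated to 32 bits. */
uint32_t sbb32(uint32_t dst, uint32_t src, unsigned cf)
{
    assert(cf <= 1);
    return dst - src - cf;
}
```

    Only the result register is modelled; the flags these instructions write are ignored.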

    Are you sure ret 12 is right? Your 2 args are only 8 bytes. ret imm16 adds the immediate to esp after popping the return address, so the count is the total change to the stack pointer due to the call/ret pair.

    Also, I assume that adding two elements won't wrap (carry or overflow), even when it's the middle element of an odd-length array.
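    If that assumption can't be made, one way around it (a sketch; average_floor is a hypothetical helper and unsigned elements are assumed) relies on the array being sorted, so the difference of the two middle elements never wraps:

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper, not from the original answer.
 * Requires lo <= hi, which holds for neighbours in a sorted array. */
uint32_t average_floor(uint32_t lo, uint32_t hi)
{
    assert(lo <= hi);
    return lo + (hi - lo) / 2;   /* never exceeds hi, so it cannot wrap */
}
```

    In the asm versions, the equivalent for unsigned elements would be to replace the final shr eax, 1 with rcr eax, 1, which shifts the carry-out of the add back in as the top bit.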


    Or, another branchless approach which is probably worse

    ; Untested
    ; using cmov on two loads, instead of sbb to make the 2nd load address dependent on CF
    GetMedian   PROC
        mov   ecx, [esp+8]            ; ecx = unsigned len
        mov   edx, [esp+4]            ; edx = int *arr
        shr   ecx, 1                  ; ecx = len/2.  CF = the bit shifted out. 0 means even, 1 means odd
    
        mov   eax, [edx + ecx*4]      ; eax = arr[len/2]
        mov   edx, [edx + ecx*4 - 4]  ; edx = arr[len/2-1]  (reads before the start if len is 0 or 1, and potentially touches a different cache line than len/2)
        cmovc edx, eax                ; CF still set from shr.  edx = odd ? arr[len/2] : edx
    
        add   eax, edx
        shr   eax, 1                  ; eax /= 2  (or sar for arithmetic shift)
        ret 8
    GetMedian ENDP
    

    Branching implementation:

    This is probably more like what you'd get from a C compiler, but some compilers might not be smart enough to branch on CF as set by the shift. I wouldn't be surprised either way, though; I think I've seen gcc or clang branch on flags set by shifts.
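    For reference, the C source in question might look something like this (a sketch; the function name and the choice of unsigned 32-bit elements are assumptions, not something from the question):

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical C equivalent of the branching version. */
uint32_t median_branchy(const uint32_t *arr, size_t len)
{
    assert(len > 0);                       /* an empty array has no median */
    uint32_t mid = arr[len / 2];           /* always needed, so load it first */
    if (len & 1)
        return mid;                        /* odd: the middle element is the median */
    return (arr[len / 2 - 1] + mid) / 2;   /* even: average the two middle elements */
}
```

    The load of arr[len / 2] is common to both paths, which is what lets the asm version issue it before the branch.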

    ; Untested
    GetMedian   PROC
        ;; return in eax.  clobbers: ecx, edx (which don't need to be saved/restored)
        mov   ecx, [esp+8]            ; ecx = unsigned len
        mov   edx, [esp+4]            ; edx = int *arr
        shr   ecx, 1                  ; ecx = len/2.  CF = the bit shifted out. 0 means even, 1 means odd
    
        mov   eax, [edx + ecx*4]      ; eax = arr[len/2]
    
        jc   @@odd   ; conditionally skip the add and shift
        add   eax, [edx + ecx*4 - 4]  ; eax += arr[len/2 - 1]
    
        shr   eax, 1                  ; eax /= 2  (or sar for arithmetic shift)
    @@odd:  ;; MASM local label, doesn't show up in the object file
        ret 8
    GetMedian ENDP
    

    Alternatively:

        jnc   @@even
        ret   8         ; fast-path for the odd case
    @@even:  ;; MASM local label, doesn't show up in the object file
        add   eax, [edx + ecx*4 - 4]  ; eax += arr[len/2 - 1]
    
        shr   eax, 1                  ; eax /= 2  (or sar for arithmetic shift)
        ret 8   ; duplicate whole epilogue here: any pop or whatever
    

    Play with the scale factor instead of shifting:

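    Test the low bit of len instead of shifting it out, and let the addressing mode do the division: for even len, arr[len/2] = [edx + (len/2)*4] = [edx + len*2]. For odd len, the extra 1 in len shows up as 2 extra bytes, hence the -2 displacement.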

    This shortens the dependency chain from len to result by one shr, but it means the first load has to come after the branch. (And without tail-duplication (separate rets), we'd need an unconditional branch somewhere to implement the if(odd){}else{} structure instead of the simpler load; if(even){}; ret structure.)
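    The address arithmetic can be checked with a C model that works in byte offsets, the way the addressing modes do (a sketch; load32 and median_scaled are hypothetical names, and 4-byte unsigned elements are assumed):

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Load a 32-bit element at a byte offset from base, like mov eax, [edx + off]. */
static uint32_t load32(const uint32_t *base, size_t byte_off)
{
    uint32_t v;
    memcpy(&v, (const unsigned char *)base + byte_off, sizeof v);
    return v;
}

/* Hypothetical model: len is scaled by 2 and never shifted right. */
uint32_t median_scaled(const uint32_t *arr, size_t len)
{
    assert(len > 0);
    if (len & 1)
        return load32(arr, len * 2 - 2);            /* odd: (len-1)/2 * 4 bytes */
    return (load32(arr, len * 2 - 4)                /* even: arr[len/2 - 1] */
          + load32(arr, len * 2)) / 2;              /*     + arr[len/2]     */
}
```

    For len = 3 the odd path reads byte offset 4 (arr[1]); for len = 4 the even path reads offsets 4 and 8 (arr[1] and arr[2]).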

    ; Untested
    GetMedian   PROC
        ;; return in eax.  clobbers: ecx, edx (which don't need to be saved/restored)
        mov   ecx, [esp+8]            ; ecx = unsigned len
        mov   edx, [esp+4]            ; edx = int *arr
        test  ecx, 1
        jz   @@even
        mov   eax, [edx + ecx*2 - 2]  ; odd
        ret   8
    @@even:
        mov   eax, [edx + ecx*2]      ; even: eax = arr[len/2]
        add   eax, [edx + ecx*2 - 4]  ; eax += arr[len/2 - 1]
        shr   eax, 1
        ret 8
    GetMedian ENDP