Tags: sorting, assembly, x86, nasm

Trying to understand the new sorting algorithm from AlphaDev: why does my assembly code not work as expected?


There is a recent publication at nature.com, Faster sorting algorithms discovered using deep reinforcement learning, which describes AlphaDev discovering a faster sorting algorithm. This caught my interest, and I've been trying to understand the discovery.

Here is the pseudo-code of the original sort3 algorithm alongside the improved version that AlphaDev discovered.

(Image: the original and AlphaDev sort3 pseudo-code side by side, transcribed below.)

Original Pseudo-Code

Memory [0] = A
Memory [1] = B
Memory [2] = C

mov Memory[0] P  // P = A
mov Memory[1] Q  // Q = B
mov Memory[2] R  // R = C

mov R S
cmp P R
cmovg P R  // R = max(A, C)
cmovl P S  // S = min(A, C)
mov S P    // P = min(A, C)
cmp S Q
cmovg Q P  // P = min(A, B, C)
cmovg S Q  // Q = max(min(A, C), B)

mov P Memory[0]  // = min(A, B, C)
mov Q Memory[1]  // = max(min(A, C), B)
mov R Memory[2]  // = max(A, C)

AlphaDev Pseudo-Code

Memory [0] = A
Memory [1] = B
Memory [2] = C

mov Memory[0] P  // P = A
mov Memory[1] Q  // Q = B
mov Memory[2] R  // R = C

mov R S
cmp P R
cmovg P R  // R = max(A, C)
cmovl P S  // S = min(A, C)

cmp S Q
cmovg Q P  // P = min(A, B)
cmovg S Q  // Q = max(min(A, C), B)

mov P Memory[0]  // = min(A, B)
mov Q Memory[1]  // = max(min(A, C), B)
mov R Memory[2]  // = max(A, C)

The improvement centers on the omission of a single move instruction, mov S P. To help myself understand it, I wrote the assembly code below. However, my testing shows that the sort does not work when A=3, B=2, and C=1, but it does work when A=3, B=1, and C=2.

This is written, compiled, and run on Ubuntu 20.04 Desktop.

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.6 LTS
Release:    20.04
Codename:   focal
$ nasm -v
NASM version 2.14.02
$ ld -v
GNU ld (GNU Binutils for Ubuntu) 2.34

My assembly code test...

; -----------------------------------------------------------------
;
; sort_test.asm
;
; Test for AlphaDev sorting algorithm
;
; My findings are that AlphaDev's removal of 'mov S P' doesn't work when:
;   a = 3, b = 2, c = 1
; But it does work with:
;   a = 3, b = 1, c = 2
;
; Output: The sorted values of a, b & c printed to stdout with spaces
;
; Compile & run with:
;
; nasm -f elf32 sort_test.asm && ld -m elf_i386 sort_test.o -o sort_test && ./sort_test
;
; -----------------------------------------------------------------

global _start

section .data
  a equ 3
  b equ 2
  c equ 1

section .bss
  buffer resb 5

section .text
_start:
; ------------------- ; AlphaDev pseudo-code

  mov eax, a          ; P = A
  mov ecx, b          ; Q = B
  mov edx, c          ; R = C
  mov ebx, edx        ; mov R S

  cmp eax, edx        ; cmp P R
  cmovg edx, eax      ; cmovg P R  // R = max(A, C)
  cmovl ebx, eax      ; cmovl P S  // S = min(A, C)

; The following line was in original sorting algorithm,
; but AlphaDev determined it wasn't necessary
;  mov eax, ebx       ; mov S P   // P = min(A, C)

  cmp ebx, ecx        ; cmp S Q
  cmovg eax, ecx      ; cmovg Q P  // P = min(A, B)
  cmovg ecx, ebx      ; cmovg S Q  // Q = max(min(A, C), B)

; convert each value to an ASCII digit and store it in the buffer, separated by spaces
; (byte-sized stores, so the 32-bit registers don't write past each slot or past the 5-byte buffer)
  add eax, 30h
  mov [buffer], al
  mov [buffer+1], byte 0x20

  add ecx, 30h
  mov [buffer+2], cl
  mov [buffer+3], byte 0x20

  add edx, 30h
  mov [buffer+4], dl

; write buffer to stdout
  mov eax, 4      ; sys_write system call
  mov ebx, 1      ; stdout file descriptor
  mov ecx, buffer ; buffer to write
  mov edx, 5      ; number of bytes to write
  int 0x80

  mov eax, 1      ; sys_exit system call
  mov ebx, 0      ; exit status 0
  int 0x80

I've run this test from the command line to print the result of the sort, and I've also used gdb to step through the executable instruction by instruction. During this debugging, I can clearly see that the register for "A", a.k.a. "P", a.k.a. eax, is never updated when A=3, B=2, and C=1, but it is updated when A=3, B=1, and C=2.
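
To make what I'm seeing concrete, here is the register trace I worked out by hand while stepping through with gdb (following my reading of the pseudo-code comments above, which may itself be wrong):

Failing case, A=3, B=2, C=1:
  P=3, Q=2, R=1, then S=R=1
  cmp P R  -> 3 > 1, so cmovg P R sets R=3 (max(A,C)); cmovl P S is not taken, S stays 1 (min(A,C))
  cmp S Q  -> 1 > 2 is false, so neither cmovg is taken: P stays 3, Q stays 2
  result: Memory = [3, 2, 3]  (not sorted, and the value 1 is lost)

Working case, A=3, B=1, C=2:
  P=3, Q=1, R=2, then S=R=2
  cmp P R  -> 3 > 2, so R=3 (max(A,C)); S stays 2 (min(A,C))
  cmp S Q  -> 2 > 1, so P=Q=1 and Q=S=2
  result: Memory = [1, 2, 3]  (sorted)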

Full disclosure... I'm not an assembly programmer. I'm also not proficient in any other specific language, but I've experimented with C, C++, JavaScript, PHP, and HTML to get small projects done. Basically, I'm self-taught in what I do know, and I've had to learn quite a bit just to get to the point of writing this test. So I could certainly be making mistakes or misunderstanding the problem.

Anyway, please help me understand why I'm observing what I am.

  • Am I misunderstanding the problem?
  • Am I misunderstanding the pseudo-code?
  • Am I making a mistake transforming the pseudo-code into assembly?
  • Is there a mistake with my assembly code?
  • Is the pseudo-code wrong?

Solution

  • TL;DR: they're confusingly showing only the last 2 of the 3 comparators in a 3-element sorting network, not a complete 3-element sort. This is presented very misleadingly, including in the diagram in their paper.


    I'd have used AT&T syntax (like cmovg %ecx, %eax in a .s file assembled with GCC) so the operand order can match the pseudocode, destination on the right.

    You're correct: I had a look at the article, and the 3-element pseudocode doesn't sort correctly when C is the smallest element. I know x86-64 asm backwards and forwards, and I don't just mean Intel vs. AT&T syntax :P Even looking at the real code, not just the comments, there's no way for the smallest element to end up in memory[0] = P if it started in R = memory[2] = C.

    I opened the article before really reading what your question was asking, and noticed the problem myself after skimming it up to the part about the actual improvement, so I haven't looked at your attempt to replicate it. But I didn't have any bias towards seeing a problem in it; I just wanted to understand it myself. There aren't any instructions that write P while reading a register that could hold the starting R value, so there's no way P can end up with it.


    The article indirectly links to their paper published in Nature (Faster sorting algorithms discovered using deep reinforcement learning by Daniel J. Mankowitz et al.); the full text is available at the Nature link.

    They use the same image of code in the actual paper, but with some explanatory text and a diagram in terms of a 3-element sorting network.

    (Figure 3 from the paper: an optimal three-element sorting network, with the segment AlphaDev improved circled.)

    Figure 3a presents an optimal sorting network for three elements (see Methods for an overview of sorting networks). We will explain how AlphaDev has improved the circled network segment. There are many variants of this structure that are found in sorting networks of various sizes, and the same argument applies in each case.

    The circled part of the network (last two comparators) can be seen as a sequence of instructions that takes an input sequence ⟨A, B, C⟩ and transforms each input as shown in Table 2a (left). However, a comparator on wires B and C precedes this operator and therefore input sequences where B ≤ C are guaranteed. This means that it is enough to compute min(A, B) as the first output instead of min(A, B, C) as shown in Table 2a (right). The pseudocode difference between Fig. 3b,c demonstrates how the AlphaDev swap move saves one instruction each time it is applied.

    So this pseudocode is just for the circled part of the sorting network, the last 2 of 3 compare-and-swap steps. In their blog article, and even in other parts of the paper like Table 2, they make it sound like this is the whole sort, not just the last 2 steps. The pseudocode even confusingly starts with values in memory, which wouldn't be the case after conditionally swapping B and C to ensure B <= C.
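
    To see that concretely against the NASM test in the question, one possible fix (my own sketch, not code from the paper) is to put the missing first comparator in front of the AlphaDev sequence, right after the three initial mov loads, so that the B <= C precondition actually holds; esi is just an arbitrary scratch-register choice here:

      ; first comparator of the network: ensure Q <= R  (i.e. B <= C)
      mov esi, ecx        ; save the original B
      cmp ecx, edx        ; compare B with C
      cmovg ecx, edx      ; Q = min(B, C)
      cmovg edx, esi      ; R = max(B, C)

      ; ... then the AlphaDev sequence exactly as in the question:
      ; mov ebx, edx / cmp eax, edx / cmovg edx, eax / cmovl ebx, eax / cmp ebx, ecx / cmovg eax, ecx / cmovg ecx, ebx

    With that in front, the failing input from the question (A=3, B=2, C=1) has already become the A=3, B=1, C=2 arrangement by the time the circled segment runs, and the program prints 1 2 3.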


    Also, it's unlikely that removing just one mov instruction is a huge speedup in a 3-element sort. See Can x86's MOV really be "free"? Why can't I reproduce this at all? - it's never free (it costs front-end bandwidth), but it has zero latency on most recent microarchitectures other than Ice Lake. I'm guessing this wasn't the case where they measured a 70% speedup!


    With AVX SIMD instructions like vpminsd dst, src1, src2 (https://www.felixcloutier.com/x86/pminsd:pminsq) / vpmaxsd, which do min and max of signed dword (32-bit) elements with a non-destructive separate destination, there's nothing to save except critical-path latency: min(B, prev_result) is still just one instruction, with no separate register-copy (vmovdqa xmm0, xmm1) needed the way there could be with only SSE4.1 in a sorting network. But latency could perhaps be significant when building a sorting network out of shuffles and SIMD min/max comparators, which, last I heard, is the state of the art for sorting large integer or FP arrays on x86-64 - a bigger deal than saving a mov in scalar cmov code!
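
    As a rough sketch of that register-copy point (my example, not from the paper): one comparator stage that keeps both the min and the max of two vectors of signed dwords, assuming the inputs are in xmm0 and xmm1, could look like this in NASM syntax:

      ; SSE4.1: 2-operand, destructive - a copy is needed to keep both results
      movdqa  xmm2, xmm0          ; save a copy of the first input
      pminsd  xmm0, xmm1          ; xmm0 = element-wise min
      pmaxsd  xmm1, xmm2          ; xmm1 = element-wise max

      ; AVX: 3-operand, non-destructive - no register copy needed
      vpminsd xmm2, xmm0, xmm1    ; xmm2 = element-wise min
      vpmaxsd xmm3, xmm0, xmm1    ; xmm3 = element-wise max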

    But lots of programs are compiled not to assume AVX is available, because unfortunately it's not universally supported: it's missing on some low-power x86 CPUs from as recently as the past couple of years, and on Pentium / Celeron CPUs before Ice Lake (so maybe as recent as 2018 or so for low-budget desktop CPUs).

    Their paper in Nature mentions SIMD sorting networks, but points out that libc++ std::sort doesn't take advantage of it, even for the case where the input is an array of float or int, rather than classes with an overloaded operator <.


    This 3-element tweak is a micro-optimization, not a "new sorting algorithm". It might still save latency on AArch64, but it only saves an instruction on x86.

    It's nice that AI can find these micro-optimizations, but they wouldn't be nearly as surprising if presented as a choice between computing the first output as min(min(A,C), B) or as min(A, B), which are equivalent here because B is already min(B, C) at that point.
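
    Spelled out (my restatement of the argument the paper makes around Table 2): the first output has to be min(A, B, C), and with B <= C guaranteed by the earlier comparator,

      min(A, B, C) = min(A, min(B, C)) = min(A, B)

    so the copy of S = min(A, C) into P before the final conditional moves is redundant.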

    Avoiding register-copy instructions with a careful choice of higher-level source is something humans can do, e.g. the choice of _mm_movehl_ps merge destination (first source operand) in my 2016 answer on Fastest way to do horizontal SSE vector sum (or other reduction) - see the comment in the compiler-generated asm, "# note the reuse of shuf, avoiding a movaps".

    Previous work in automated micro-optimization includes STOKE, a stochastic superoptimizer that randomly tries instruction sequences hoping to find cheap sequences that match the outputs of a test function you give it. The search space is so large that it tends to miss possible sequences when it takes more than 3 or 4 instructions (STOKE's own page says it's not production-ready, just a research prototype). So AI is helpful. And it's a lot of work to look at asm by hand for possible missed optimizations that could be fixed by tweaking the source.

    But at least for this 3-element subproblem, it is just a micro-optimization, not truly algorithmically new. It's still just a 3-comparator sorting network. One that compiles more cheaply for x86-64, which is nice. But on some 3-operand ISAs with a separate destination for their equivalent of cmov, like AArch64's csel dst, src1, src2, flag_condition conditional-select, there's no mov to save. It could still save latency on the critical path, though.

    Their paper in Nature also shows an algorithmic difference for sorting a variable number of elements, where the >= 3 cases both start by sorting the first 3. Maybe this helps branch prediction since that work can be in flight while a final branch on len > 3 is resolving to see whether they need to do a simplified 4-element sort that can assume the first 3 elements are sorted. They say "It is this part of the routine that results in significant latency savings." (They also call this a "fundamentally new" algorithm, which I presume is true for the problem of using sorting networks on short unknown-length inputs.)