Programming with RISC-V: how to write cleaner, less ugly code for the Collatz conjecture?


I want to write a program in RISC-V that computes the following function on strictly positive integers:

f(n) = f(n/2) if n is even
        f(3n+1) if n!=1 is odd
        1 if n=1

We assume that a strictly positive integer is initially stored in register a0, as input to the program.

I came up with the following code:

 li t1 1
 li t2 2
 li t3 3

test:
 andi t0 a0 1 
 beqz t0, iseven

isodd:
 beq a0, t1, exit
 mul a0, a0, t3
 add a0, a0, t1
 j test

iseven:
 div a0 a0 t2
 j test

exit:

Maybe it's correct, but it definitely seems awkward. I suspect there is a more elegant way to write it. For instance, the three load-immediate instructions at the beginning look ugly, and so do the conditional branches.

How can I improve the code above? I'm an absolute beginner in RISC-V.


Solution

  • This is a function for testing the Collatz conjecture (https://en.wikipedia.org/wiki/Collatz_conjecture), which is unproven, but nobody has ever found a number that doesn't eventually reach 1 when these steps are applied repeatedly.

    The definition you were given is recursive, but you've correctly implemented it as an iterative loop, which is much more efficient in assembly. The detail to watch is the loop-exit condition (the base case): notice that it's 1 if n=1, not f(1), so you don't keep looping once n reaches 1.
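    To make the recursive-versus-iterative point concrete, here is a minimal C sketch. The names collatz_recursive and collatz_iterative are hypothetical (they are not from the question), and both functions assume n >= 1 and that 3n+1 never overflows 64 bits.

```c
#include <stdint.h>

/* Direct transcription of the recursive definition (assumes n >= 1). */
uint64_t collatz_recursive(uint64_t n)
{
    if (n == 1)
        return 1;                          /* base case: 1, not f(1) */
    if (n % 2 == 0)
        return collatz_recursive(n / 2);   /* even case */
    return collatz_recursive(3 * n + 1);   /* odd case, n != 1 */
}

/* Same function as a loop: every recursive call above is a tail call,
   so it can be replaced by updating n and jumping back to the top. */
uint64_t collatz_iterative(uint64_t n)
{
    while (n != 1) {
        if (n % 2 == 0)
            n = n / 2;
        else
            n = 3 * n + 1;
    }
    return n;
}
```

    The loop version needs no call stack, and this is the shape the assembly implements.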


    Your code is fairly clean (and well-formatted, although I'd indent the operands to a consistent column). As you suspect, it's possible (and more efficient in this case) with fewer constants set up in registers ahead of the loop.

    I found the iseven label name reads like "i seven", so I changed to is_even and is_odd. If you like shorter label names, odd and even work. (Or .Lodd and .Leven to make them "local" labels in GNU assembler, no metadata in the .o).

    You don't need a div, just a right shift by 1. (Related: "Why does C++ code for testing the Collatz conjecture run faster than hand-written assembly?", for x86.) Division is slow, often very slow. Bit-shifts are fast and can take the shift count as an immediate operand. As I mentioned in that Q&A, optimized compiler output (Godbolt) can be useful to look at for simple-enough functions. GCC and clang both set up only a single constant, 1. That function returns an iteration count; yours just returns (or, if a number didn't converge, would loop forever). You can try different optimization levels like -O1 or -Og to see less mixing together of work from different statements.
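    As a rough illustration of the shift-versus-divide point, here is a small C sketch of the kind you could paste into a compiler explorer. The function names are hypothetical, and collatz_steps is only an assumed stand-in for the iteration-counting function from the linked Q&A, not a copy of it. It assumes n >= 1 and no 32-bit overflow.

```c
#include <stdint.h>

/* For unsigned n, n / 2 and n >> 1 always give the same result,
   so an optimizing compiler can use a shift instead of a divide. */
uint32_t half_div(uint32_t n)   { return n / 2; }
uint32_t half_shift(uint32_t n) { return n >> 1; }

/* Hypothetical helper: counts the steps until n reaches 1. */
unsigned collatz_steps(uint32_t n)
{
    unsigned steps = 0;
    while (n != 1) {
        n = (n & 1) ? 3 * n + 1 : n >> 1;  /* odd: 3n+1, even: n/2 */
        steps++;
    }
    return steps;
}
```

    For example, starting from 6 the sequence is 6, 3, 10, 5, 16, 8, 4, 2, 1, so collatz_steps(6) returns 8.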

    Similarly, you don't need a multiply; a compiler would typically do 3*n + 1 as (n+1) + (n<<1), so three instructions instead of 2 but no mul and still only critical-path latency of 2 ALU operations (since there's instruction-level parallelism.) Depending on the microarchitecture, mul can be pretty cheap; unlike division there's a lot of inherent parallelism in adding up partial products, so throwing a lot of transistors at the problem can get it done fast. But at least on CPUs with high clock speeds, it usually has more than 2 cycle latency. 3 is common in modern x86.
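    The identity 3*n + 1 = (n+1) + (n<<1) can be sketched in C as below. The function names are hypothetical. Unsigned arithmetic is used so that both forms wrap the same way on overflow.

```c
#include <stdint.h>

/* Straightforward form: one multiply, one add (a dependent chain). */
uint32_t three_n_plus_1_mul(uint32_t n)
{
    return 3 * n + 1;
}

/* Shift-and-add form: the shift and the +1 are independent of each
   other, so they can execute in parallel before the final add. */
uint32_t three_n_plus_1_shift(uint32_t n)
{
    uint32_t twice = n << 1;   /* n*2 */
    uint32_t plus1 = n + 1;    /* n+1 */
    return twice + plus1;      /* n*2 + n+1 = n*3 + 1 */
}
```

    Both functions return the same value for every 32-bit input, including inputs where 3n+1 wraps around.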

    So that gets rid of the 2 and 3 constants, and you can use addi like you're doing with andi.

    You do need 1 in a register for the loop-exit condition, though, to detect when the sequence has converged to 1. RISC-V compare-and-branch instructions use all their bits on a relative displacement, no bits for an immediate to compare against.

    You have an unconditional j at the bottom; if you arrange your loop differently, you can put the loop-exit condition there so you don't also need another branch inside the loop for that. (Why are loops always compiled into "do...while" style (tail jump)?)

    .global collatz_test
    collatz_test:            # function entry point
     li   t1, 1
     beq  a0, t1, exit       # or  j to the loop entry-point.
                             # or just fall into the loop and let it do 3*1 + 1 = 4
                              # then reach 1 with 2 more iterations.
    
    testloop:              # do {
     andi   t0, a0, 1 
     beqz   t0, is_even
    
    is_odd:
     slli   t2, a0, 1        # n*2
     addi   a0, a0, 1        # n+1
     add    a0, a0, t2       # n*2 + n+1 = n*3 + 1
          # mul  a0, a0, t3    # original version, replaced by the shift/add above
          # add  a0, a0, t1
     # j testloop
                             # this can't produce a result of 1:
                             # even if the n*=3 overflows, it can't clear the lowest set bit since 3 is odd,
                             # so non-zero n already has n*3 != 0.
                             # Actually, 3n+1 is always even since we only get here with odd n,
                             # so we can just fall through into is_even.
    
    is_even:            # this *can* produce a 1 we need to check for
     srli   a0, a0, 1        # n /= 2  (logical right shift)
     bne    a0, t1, testloop # }while(n != 1)
    
    exit:
    # ret   # if this was a function, not a whole program
    

    So this is a do{}while(n != 1) loop. With the j testloop at the end of is_odd (now commented out), it had two branches back to the top of the loop, like a loop tail-duplication optimization, but in one of them the n != 1 loop condition was optimized away so it was just a j. Otherwise I'd have had to do

    # tail-duplication *without* being able to skip the loop-exit condition in one half
    is_odd:
     ...
       add ...
       bne  a0, t1, testloop      # }while(n != 1)
       j    exit                  # don't fall through into the other side of the if
    
    is_even:
      ...
    

    Or as @Andrey Turkin commented, we can go even farther towards optimizing for this specific loop and remove the j testloop branch for the is_odd case: 3*n + 1 is always even for odd n, so the next loop iteration would jump to is_even. We can go there directly without another loop iteration, as well as without checking for 1. (The reasoning in comments about 3*n + 1 always being non-1 was superfluous; that block only runs with odd n anyway so the argument is simpler.)
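    As a cross-check of the control flow, the fall-through loop can be modelled in C roughly as follows. This is a sketch, not compiler output. The name collatz_test_model is hypothetical, and it assumes n >= 1 and no 32-bit overflow.

```c
#include <stdint.h>

/* C model of the assembly loop: the odd case falls through into the
   even case, and the n != 1 test sits only at the bottom of the loop. */
uint32_t collatz_test_model(uint32_t n)
{
    if (n == 1)                      /* beq  a0, t1, exit */
        return n;
    do {
        if (n & 1)                   /* andi / beqz: skip when even */
            n = (n << 1) + (n + 1);  /* is_odd: 3n+1, always even here */
        n >>= 1;                     /* is_even: srli a0, a0, 1 */
    } while (n != 1);                /* bne  a0, t1, testloop */
    return n;
}
```

    Each odd step is fused with the halving that must follow it, so an odd n costs one trip through the loop rather than two.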