I want to write a program in RISC-V assembly that computes the following function on strictly positive integers:

f(n) = f(n/2)    if n is even
f(n) = f(3n+1)   if n is odd and n != 1
f(n) = 1         if n = 1

We assume that a strictly positive integer is initially stored in register a0 as input to the program.
I came up with the following code:
li t1 1
li t2 2
li t3 3
test:
andi t0 a0 1
beqz t0, iseven
isodd:
beq a0, t1, exit
mul a0, a0, t3
add a0, a0, t1
j test
iseven:
div a0 a0 t2
j test
exit:
Maybe it's correct, but it definitely seems awkward. I suspect there is a more elegant way to write it. For instance, the three load-immediate instructions at the beginning look ugly, and so do the conditional branches.
How can I improve the code above? I'm an absolute beginner in RISC-V.
This is a function for testing the Collatz conjecture (https://en.wikipedia.org/wiki/Collatz_conjecture), which is unproven, but nobody has ever found a number that doesn't eventually converge to 1 when these steps are applied repeatedly.
The definition you were given is recursive, but you've correctly implemented it as an iterative loop, because that's much more efficient in assembly. The base case becomes the loop exit condition: notice how it's 1 if n=1, not f(1), so you don't keep looping then. Your beq a0, t1, exit in isodd handles that.
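To make the recursive-versus-iterative point concrete, here is a minimal C sketch of both forms. C is used instead of RISC-V assembly only so the sketch can be compiled and checked on any machine; the names f_recursive and f_iterative are made up for illustration, and both assume a strictly positive input whose sequence reaches 1 without overflowing 32 bits.

```c
#include <assert.h>
#include <stdint.h>

/* Recursive form, straight from the definition. */
uint32_t f_recursive(uint32_t n) {
    if (n == 1) return 1;                  /* base case: stop here */
    if ((n & 1) == 0) return f_recursive(n / 2);
    return f_recursive(3 * n + 1);
}

/* Iterative form, doing the same work as the question's loop. */
uint32_t f_iterative(uint32_t n) {
    while (n != 1) {                       /* beq a0, t1, exit */
        if ((n & 1) == 0)
            n = n / 2;                     /* iseven: div a0, a0, t2 */
        else
            n = 3 * n + 1;                 /* isodd: mul + add */
    }
    return n;                              /* exit: a0 == 1 */
}
```

Both return 1 for inputs like 27; the loop version just doesn't rely on the compiler turning the tail calls into jumps.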
Your code is fairly clean (and well formatted, although I'd indent the operands to a consistent column). As you suspect, it's possible (and in this case more efficient) to set up fewer constants in registers ahead of the loop.
I found that the label name iseven reads like "i seven", so I changed it to is_even and is_odd. If you like shorter label names, odd and even work. (Or .Lodd and .Leven to make them "local" labels in GNU assembler, so they don't get symbol-table entries in the .o.)
You don't need a div; just right-shift by 1 with srli. (Related: Why does C++ code for testing the Collatz conjecture run faster than hand-written assembly? for x86.) Division is slow, often very slow. Bit-shifts are fast and can take the shift count as an immediate operand. (As I mentioned in that Q&A, optimized compiler output (Godbolt) can be useful to look at for simple-enough functions. GCC and clang both set up only a single constant, 1. That function returns an iteration count; yours just returns. (Or, if a number didn't converge, it would loop forever.) You can try different optimization levels like -O1 or -Og to see less mixing together of work from different statements.)
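A quick C check of the shift-instead-of-divide point, assuming unsigned 32-bit values (which covers the strictly positive inputs here). The helper names are hypothetical. For unsigned n, truncating division n / 2 and the logical shift n >> 1 give the same result, which is why srli a0, a0, 1 can replace div.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Both discard the low bit of an unsigned value. */
static inline uint32_t half_div(uint32_t n)   { return n / 2; }
static inline uint32_t half_shift(uint32_t n) { return n >> 1; }   /* srli a0, a0, 1 */

/* Compare the two for every n in [lo, hi]; the 64-bit counter means the
   loop still terminates when hi == UINT32_MAX. */
bool halves_agree(uint32_t lo, uint32_t hi) {
    for (uint64_t n = lo; n <= hi; n++)
        if (half_div((uint32_t)n) != half_shift((uint32_t)n))
            return false;
    return true;
}
```

For signed values that may be negative, the two differ (-3 / 2 is -1 in C, while an arithmetic right shift of -3 by 1 gives -2), which is why compilers emit extra instructions for signed division by 2 unless they can prove the value is non-negative.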
Similarly, you don't need a multiply; a compiler would typically compute 3*n + 1 as (n+1) + (n<<1): three instructions instead of two, but no mul, and still a critical-path latency of only 2 ALU operations (since there's instruction-level parallelism). Depending on the microarchitecture, mul can be pretty cheap; unlike division, there's a lot of inherent parallelism in adding up partial products, so throwing a lot of transistors at the problem can get it done fast. But at least on CPUs with high clock speeds, it usually has more than 2-cycle latency; 3 cycles is common on modern x86.
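Here is the same idea for the multiply, as a C sketch with made-up function names. Computing (n+1) + (n<<1) gives exactly the same 32-bit result as 3*n + 1, even when it wraps around, because both are computed modulo 2^32. The comments map each C line to the shift/add sequence used in the rewritten loop.

```c
#include <assert.h>
#include <stdint.h>

/* With a multiply, like mul + add in the question. */
uint32_t triple_plus_one_mul(uint32_t n) { return 3 * n + 1; }

/* Without mul. The first two operations don't depend on each other,
   so the critical path is only two ALU operations. */
uint32_t triple_plus_one_shift(uint32_t n) {
    uint32_t twice = n << 1;   /* slli t2, a0, 1   (n*2) */
    uint32_t plus1 = n + 1;    /* addi a0, a0, 1   (n+1) */
    return plus1 + twice;      /* add  a0, a0, t2  (n*3 + 1) */
}
```
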
So that gets rid of the 2 and 3 constants, and you can use addi like you're doing with andi.
You do still need 1 in a register for the loop-exit condition, though, to detect when the sequence has converged to 1. RISC-V compare-and-branch instructions spend their encoding bits on two source registers and a relative displacement, leaving no bits for an immediate to compare against.
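To see why there's no room for an immediate, here is a hypothetical C helper (encode_branch is not a real library function) that packs a B-type branch following the layout in the RISC-V base ISA specification. The two 5-bit register numbers, the 3-bit funct3, the 7-bit opcode and the 12 stored offset bits add up to all 32 bits.

```c
#include <assert.h>
#include <stdint.h>

/* B-type layout (BEQ/BNE/BLT/BGE/BLTU/BGEU):
     bit 31      imm[12]
     bits 30:25  imm[10:5]
     bits 24:20  rs2
     bits 19:15  rs1
     bits 14:12  funct3 (BEQ=0, BNE=1, BLT=4, BGE=5, BLTU=6, BGEU=7)
     bits 11:8   imm[4:1]
     bit 7       imm[11]
     bits 6:0    opcode 0x63
   imm[0] is always 0 (offsets are multiples of 2), so it isn't stored. */
uint32_t encode_branch(uint32_t funct3, uint32_t rs1, uint32_t rs2, int32_t offset) {
    uint32_t imm = (uint32_t)offset & 0x1FFE;   /* 13-bit offset, bit 0 dropped */
    return ((imm >> 12) & 0x1)  << 31
         | ((imm >> 5)  & 0x3F) << 25
         | (rs2 & 0x1F)         << 20
         | (rs1 & 0x1F)         << 15
         | (funct3 & 0x7)       << 12
         | ((imm >> 1)  & 0xF)  << 8
         | ((imm >> 11) & 0x1)  << 7
         | 0x63;
}
```

For example, encode_branch(1, 10, 6, -8) packs bne a0, t1, -8 (a0 is x10, t1 is x6), a backwards branch of the kind the loop uses. Comparing against 1 therefore means putting 1 in a register first.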
You have an unconditional j at the bottom; if you arrange your loop differently, you can put the loop-exit condition there, so you don't also need another branch inside the loop for that. (Why are loops always compiled into "do...while" style (tail jump)?)
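The loop rotation from that link can be sketched in C as two equivalent loops. The step counter is an addition (similar to the iteration count returned by the compiler-output version mentioned above) so the two can be compared; the function names are made up.

```c
#include <assert.h>
#include <stdint.h>

/* "while" form: the condition is checked at the top, so each iteration
   needs a conditional branch at the top plus a jump back at the bottom. */
uint32_t steps_while(uint32_t n) {
    uint32_t steps = 0;
    while (n != 1) {
        n = (n & 1) ? 3 * n + 1 : n >> 1;
        steps++;
    }
    return steps;
}

/* Rotated "do ... while" form: one check before the loop, then the loop
   condition sits at the bottom and doubles as the backwards branch. */
uint32_t steps_do_while(uint32_t n) {
    uint32_t steps = 0;
    if (n == 1) return steps;          /* like beq a0, t1, exit before the loop */
    do {
        n = (n & 1) ? 3 * n + 1 : n >> 1;
        steps++;
    } while (n != 1);                  /* like bne a0, t1, testloop */
    return steps;
}
```

Both give 8 steps for n = 6 and 111 for n = 27; only the branch structure differs.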
.global collatz_test
collatz_test: # function entry point
li t1, 1
beq a0, t1, exit # or j to the loop entry-point.
# or just fall into the loop and let it do 3*1 + 1 = 4
# then reach 1 with 2 more iterations.
testloop: # do {
andi t0, a0, 1
beqz t0, is_even
is_odd:
slli t2, a0, 1 # n*2
addi a0, a0, 1 # n+1
add a0, a0, t2 # n*2 + n+1 = n*3 + 1
# mul a0, a0, t3
# add a0, a0, t1
#j testloop
# this can't produce a result of 1
# even if the n*=3 overflows, it can't clear the lowest set bit since 3 is odd.
# so non-zero n already has n*3 != 0
# actually 3n+1 is always even since we only get here with odd n
# so we can just fall through into is_even
is_even: # this *can* produce a 1 we need to check for
srli a0, a0, 1       # n >>= 1 (same as n/2 for positive n)
bne a0, t1, testloop # }while(n != 1)
exit:
# ret # if this was a function, not a whole program
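As a cross-check of the control flow, here is a C model of collatz_test that uses goto so each label lines up with the assembly. It's a sketch: the model's name and the steps output parameter are hypothetical additions (the assembly doesn't count anything), and it uses 32-bit unsigned arithmetic like RV32.

```c
#include <assert.h>
#include <stdint.h>

uint32_t collatz_test_model(uint32_t n, uint32_t *steps) {
    uint32_t count = 0;
    if (n == 1) goto done;               /* beq a0, t1, exit */
testloop:
    if ((n & 1) == 0) goto is_even;      /* andi t0, a0, 1 / beqz t0, is_even */
    /* is_odd: n*2 + (n+1) = 3n + 1, which is even for odd n */
    n = (n + 1) + (n << 1);
    count++;
    /* fall through into is_even: no j testloop, no n != 1 check */
is_even:
    n >>= 1;                             /* srli a0, a0, 1 */
    count++;
    if (n != 1) goto testloop;           /* bne a0, t1, testloop */
done:
    if (steps) *steps = count;
    return n;                            /* exit: a0 == 1 */
}
```

Counting each 3n+1 and each halving as one step, the model reports the usual step counts (8 for n = 6, 111 for n = 27), which suggests the fall-through skips no work.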
So this is a do{}while(n != 1) loop. With the j testloop at the end of is_odd (commented out above), it has two branches back to the top of the loop, like the loop tail-duplication optimization, but in one of them I optimized away the n != 1 loop condition so it's just a j. Otherwise I'd have had to do:
# tail-duplication *without* being able to skip the loop-exit condition in one half
is_odd:
...
add ...
bne a0, t1, testloop # }while(n != 1)
j exit # don't fall through into the other side of the if
is_even:
...
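To put a number on what the fall-through saves, here is a hypothetical instrumented C model (not a real profiler) that counts the branch instructions each version would execute: the tail-duplicated version just shown versus collatz_test, where is_odd falls through into is_even. Both counts include the initial beq a0, t1, exit.

```c
#include <assert.h>
#include <stdint.h>

/* Tail duplication without skipping the exit check in the odd half. */
uint32_t branches_tail_dup(uint32_t n) {
    uint32_t br = 1;                         /* beq a0, t1, exit */
    if (n == 1) return br;
    for (;;) {
        br++;                                /* beqz t0, is_even */
        if (n & 1) {
            n = (n + 1) + (n << 1);          /* is_odd */
            br++;                            /* bne a0, t1, testloop */
            if (n == 1) return br + 1;       /* j exit (never reached for odd n) */
        } else {
            n >>= 1;                         /* is_even */
            br++;                            /* bne a0, t1, testloop */
            if (n == 1) return br;
        }
    }
}

/* is_odd falls through into is_even, as in collatz_test. */
uint32_t branches_fallthrough(uint32_t n) {
    uint32_t br = 1;                         /* beq a0, t1, exit */
    if (n == 1) return br;
    for (;;) {
        br++;                                /* beqz t0, is_even */
        if (n & 1)
            n = (n + 1) + (n << 1);          /* is_odd, then fall into is_even */
        n >>= 1;                             /* is_even */
        br++;                                /* bne a0, t1, testloop */
        if (n == 1) return br;
    }
}
```

For n = 6, which reaches 1 in 8 steps (two of them odd), the tail-duplicated version executes 17 branches and the fall-through version 13: each odd step saves one beqz and one bne.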
As @Andrey Turkin commented, we can go even further towards optimizing for this specific loop and remove the j testloop branch for the is_odd case (that's why it's commented out in the code above): 3*n + 1 is always even for odd n, so the next loop iteration would jump to is_even anyway. We can go there directly, skipping both the extra loop iteration and the check for 1. (The reasoning in comments about 3*n + 1 always being non-1 was superfluous; that block only runs with odd n anyway, so the argument is simpler.)
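The parity argument is easy to check mechanically. This C sketch (function names hypothetical) computes 3n+1 the same way is_odd does and confirms the result is even for odd n, including when the 32-bit arithmetic wraps: 3n is odd times odd, so it's odd, and adding 1 makes it even; reducing modulo 2^32 doesn't change bit 0. An even number is never 1, so no n != 1 check is needed before halving.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* 3n + 1 computed as in is_odd; true if the result is even. */
bool odd_step_is_even(uint32_t n) {
    uint32_t next = (n + 1) + (n << 1);
    return (next & 1) == 0;
}

/* Check every odd n in [1, limit]; the 64-bit counter avoids wrapping. */
bool check_odd_range(uint32_t limit) {
    for (uint64_t n = 1; n <= limit; n += 2)
        if (!odd_step_is_even((uint32_t)n))
            return false;
    return true;
}
```
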