Search code examples
assemblyrecursion64-bitx86-64att

Recursive factorial subroutine in x64 assembly gives stack overflow


I am implementing a recursive algorithm to calculate the factorial of a given number in x64 assembly. (I am using the AT&T syntax.)

The pseudocode looks like this:

int factorial(int x){
    if(x == 1){
        return 1;
    }else{
        return x*factorial(x-1);
    }
 }

Now, my implementation in x64 assembly:

factorial:  
    #Start of subroutine
    pushq %rbp
    movq %rsp, %rbp 

    cmpq $1, %rdi   #If rdi == 1, return 1, otherwise return x*factorial(x-1)
    je if           #je : %rdi == 1

    jmp else

if:                 #return 1
    movq $1, %rax
    jmp factend

else:               #return x*factorial(x-1)
    pushq %rdi      #Save x
    subq $1,%rdi    #Calculate x-1

    call factorial  #Calculate factorial from x-1
    popq %rdi       #Get rdi value before decrement
    mulq %rdi       #Multiply with rax, rax is either 1, or result of previous multiplication

    jmp factend     #End this subroutine, continue with previous 

factend:
    #End of subroutine
    movq %rbp, %rsp
    popq %rbp
    ret

This implementation doesn't stop, however. I get a segmentation fault, which is caused by a stack overflow. The if block is never executed, and the subroutine is stuck in a loop with the else code. If I follow my implementation step by step and write down the values of registers and the stack, I don't seem to run in to a problem. What might cause this?

EDIT How I retrieve the input value:

formatinput: .asciz "%d"

#Get input from terminal
subq $8, %rsp
leaq -8(%rbp), %rsi
movq $0, %rax
movq $formatinput, %rdi
call scanf

#Calculate factorial of input value
movq -8(%rbp), %rdi
movq $1, %rax
call factorial

Another EDIT

My complete code:

#Define main
.global main

.global inout

#Define string to be printed
formatinput: .asciz "%d"
formatoutput: .asciz "%d\n"

str:    .asciz "Assignment %d"


main:


#Start of program   
    movq %rsp, %rbp

    #Print statement
    movq $0, %rax
    movq $4, %rsi
    movq $str, %rdi
    call printf

    call inout

end:
    #Exit program with code 0, no errors
    movq $0, %rdi
    call exit

#inout subroutine
inout:  
    #Start of subroutine
    pushq %rbp
    movq %rsp, %rbp

    #Get input from terminal
    subq $8, %rsp
    leaq -8(%rbp), %rsi
    movq $0, %rax
    movq $formatinput, %rdi
    call scanf

    #Calculate factorial of input value
    movq -8(%rbp), %rdi
    movq $1, %rax
    call factorial
    movq %rax, -8(%rbp)

    #Print result
    movq $0, %rax
    movq -8(%rbp), %rsi
    movq $formatoutput, %rdi
    call printf

    movq %rbp, %rsp
    popq %rbp
    ret

#factorial subroutine
# int factorial(int x){
#   if(x == 1){
#       return 1;
#   }else{
#       return x*factorial(x-1);
#   }
# }

factorial:  
    #Start of subroutine
    pushq %rbp
    movq %rsp, %rbp 

    cmpq $1, %rdi       #If rdi == 1, return 1, otherwise return x*factorial(x-1)
    jg if           #jg : %rdi > $1

    jmp else

if:             #return 1
    movq $1, %rax
    jmp factend

else:               #return x*factorial(x-1)
    pushq %rdi      #Save x
    subq $1,%rdi        #Calculate x-1

    call factorial      #Calculate factorial from x-1
    popq %rdi       #Get rdi value before decrement
    mulq %rdi       #Multiply with rax, rax is either 1, or result of previous multiplication

    jmp factend     #End this subroutine, continue with previous 

factend:
    #End of subroutine
    movq %rbp, %rsp
    popq %rbp
    ret

#print test subroutine
print:
    #Start of subroutine
    pushq %rbp
    movq %rsp, %rbp

    pushq %rdi
    pushq %rsi
    pushq %rax

    movq $0, %rax
    movq %rdi, %rsi
    movq $formatoutput, %rdi
    call printf

    popq %rsi
    popq %rdi
    popq %rax

    movq %rbp, %rsp
    popq %rbp
    ret

Solution

  • You're using 64-bit integers, while your c-code uses int, which is usually 32-bit. So your scanf("%d") doesn't touch the upper 32 bits of the value you load into %rdi to pass to factorial().
    Whatever was in those upper bits before the scanf() is now interpreted as part of the number you pass, so instead of an input like 1, factorial() interprets it as something like 18612532834992129, which causes stack overflow.

    You could either replace the movq -8(%rbp), %rdi after scanf() with movl -8(%rbp), %edi, or change the scanf()-format specifier from %d to %ld.

    The movl-variant shows an interesting tidbit about x86-64: Using a 32-bit operation implicitly clears the upper 32-bits of the 64-bit register (the exception being xchg %eax, %eax, since this one is the canonical nop).