
Stop all recursive functions in a goroutine


I start a goroutine that runs a recursive function, and I want to send a signal that stops every active call of that function. This is the function (what it computes is not important):

func RecursiveFunc(x int, depth int, quit chan bool) int {

    if depth == 0 {
        return 1
    }

    if quit != nil {
        select {
        case <-quit:
            return 0
        default:
        }
    }

    total := 0

    for i := 0; i < x; i++ {

        y := RecursiveFunc(x, depth - 1, quit)

        if y > 0 {
            total += y
        }

    }

    return total
}

This function may take a long time to finish. I want to stop it by sending a quit signal and then use whatever result it has produced so far. To run it:

import (
    "fmt"
    "time"
    "sync"
)

func main() {

    quit := make(chan bool)
    wg := &sync.WaitGroup{}
    result := -1

    wg.Add(1) // register before starting the goroutine so wg.Wait cannot return early
    go func() {
        defer wg.Done()
        result = RecursiveFunc(5, 20, quit)
    }()

    time.Sleep(10 * time.Millisecond)

    close(quit) // Using `quit <- true` doesn't work

    wg.Wait()

    fmt.Println(result)
}

To stop the goroutine I use a channel, quit. After I close it, the program works well. However, I don't really want to close the channel; I just want to send a signal with quit <- true. That doesn't work: it probably stops only one instance of the recursion.

How can I stop all instances of recursive function by sending a quit signal?


Solution

  • You can do this using context.

    Pass a context.Context as the first parameter to the function you need to stop from outside, and call the corresponding cancel function to send the cancellation signal. Calling cancel closes the channel returned by the context's Done() method, and the function notices this in a select statement. Closing is what makes the signal reach every call: a value sent with quit <- true is consumed by a single receive, whereas a receive on a closed channel always succeeds immediately, so every recursive call sees it.
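The difference between sending one value and closing the channel can be checked in isolation. The following is a minimal sketch, not part of the original answer: countSignals is a hypothetical helper that repeats the same non-blocking select each recursive call performs, and the buffered channel is an assumption made only so the send does not block in a single goroutine.

```go
package main

import "fmt"

// countSignals performs n non-blocking receives on quit, the same check each
// recursive call makes, and counts how many of them observe a signal.
// (Hypothetical helper, for illustration only.)
func countSignals(quit chan bool, n int) int {
	seen := 0
	for i := 0; i < n; i++ {
		select {
		case <-quit:
			seen++
		default:
		}
	}
	return seen
}

func main() {
	// Assumption for this demo: a buffer of 1, so the send does not block
	// when no receiver is waiting yet.
	sent := make(chan bool, 1)
	sent <- true
	fmt.Println(countSignals(sent, 5)) // prints 1: the value is consumed once

	closed := make(chan bool)
	close(closed)
	fmt.Println(countSignals(closed, 5)) // prints 5: every receive succeeds
}
```

Only the first check sees the sent value, which matches the behaviour described in the question, while all five checks see the closed channel.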

    Here is how the function handles the cancellation signal using context.Context:

    func RecursiveFunc(ctx context.Context, x int, depth int) int {
    
        if depth == 0 {
            return 1
        }
    
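        // Non-blocking check: once cancel is called, ctx.Done() is closed
        // and this case is ready in every call.
        select {
        case <-ctx.Done():
            return 0
        default:
        }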
    
        total := 0
    
        for i := 0; i < x; i++ {
    
            y := RecursiveFunc(ctx, x, depth-1)
    
            if y > 0 {
                total += y
            }
    
        }
    
        return total
    }
    

    And here is how you can call the function with the new signature:

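    import (
        "context"
        "fmt"
        "sync"
        "time"
    )
    
    func main() {
    
        wg := &sync.WaitGroup{}
        result := -1
    
        ctx, cancel := context.WithCancel(context.Background())
    
        // Register with the WaitGroup before starting the goroutine;
        // calling wg.Add inside it races with wg.Wait below.
        wg.Add(1)
        go func() {
            defer wg.Done()
            result = RecursiveFunc(ctx, 5, 20)
        }()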
    
        time.Sleep(10 * time.Millisecond)
    
        cancel()
    
        wg.Wait()
    
        fmt.Println(result)
    }
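
If the goal is simply to stop after a fixed time, as the Sleep followed by cancel() suggests, context.WithTimeout can replace both. The following is a sketch under that assumption: runFor is a hypothetical helper name, and RecursiveFunc is repeated only to keep the example self-contained.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// RecursiveFunc is the context-aware function from the answer, repeated here
// so the example compiles on its own.
func RecursiveFunc(ctx context.Context, x int, depth int) int {
	if depth == 0 {
		return 1
	}
	select {
	case <-ctx.Done():
		return 0
	default:
	}
	total := 0
	for i := 0; i < x; i++ {
		if y := RecursiveFunc(ctx, x, depth-1); y > 0 {
			total += y
		}
	}
	return total
}

// runFor (hypothetical helper) runs the recursion for at most ms milliseconds
// and reports the partial result and whether the deadline cut it short.
func runFor(ms int) (result int, timedOut bool) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(ms)*time.Millisecond)
	defer cancel() // release the timer even if the work finishes early

	result = RecursiveFunc(ctx, 5, 20)
	return result, ctx.Err() == context.DeadlineExceeded
}

func main() {
	result, timedOut := runFor(10)
	fmt.Println(result, timedOut)
}
```

With a timeout the extra goroutine and WaitGroup are not needed, because the context is cancelled by its own timer while RecursiveFunc runs in the calling goroutine.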