
[Golang]: Goroutine cancellation mid-execution


I was watching Rob Pike's fascinating talk "Concurrency is not Parallelism" and tried to extend its example with concurrent queries, but it seems I don't fully understand how goroutines work. In particular, I don't understand how to cancel a goroutine during its execution.

Here is a dummy connection that can perform a query (suppose it is unchangeable, like an external API):

type Conn struct {
    id    int
    delay time.Duration
}

type Result int

func (c *Conn) DoQuery(query string) Result {
    defer fmt.Printf("DoQuery call for connection %d finished\n", c.id)

    time.Sleep(c.delay)
    return Result(len(query))
}

Below is the Query function that sends the query to every connection. When it gets the first result, it is supposed to cancel all the other calls (as a teardown), but it does not. Once a goroutine starts the query, it loses the ability to watch the ctx.Done() channel, because the conn.DoQuery(...) call is blocking.

func Query(conns []Conn, query string) Result {
    results := make(chan Result, len(conns))

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    for _, conn := range conns {
        go func() {
            defer fmt.Printf("goroutine for connection %d finished\n", conn.id)

            select {
            case <-ctx.Done():
                return
            default:
                // While the goroutine computes the result of this call
                // it does not watch the ctx.Done() channel, so the call
                // cannot be canceled mid-execution.
                results <- conn.DoQuery(query)
            }
        }()
    }

    return <-results
}

Here is also the main function that ties everything together (in case you want to run the example):

func main() {
    conns := []Conn{
        {id: 1, delay: 1 * time.Second},
        {id: 2, delay: 3 * time.Second},
        {id: 3, delay: 5 * time.Second},
        {id: 4, delay: 4 * time.Second},
    }

    start := time.Now()
    result := Query(conns, "select count(*) from users;")
    duration := time.Since(start)

    fmt.Println(result, duration)

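    // Keep the process alive long enough to see every goroutine finish
    // (the slowest connection takes 5s) without spinning a CPU core.
    time.Sleep(6 * time.Second)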
}

So my question is: what is the correct way to implement the Query function to achieve the desired behavior?


Solution

  • Your DoQuery function has to support cancellation (Go Playground):

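    func (c *Conn) DoQuery(ctx context.Context, query string) (Result, error) {
        defer fmt.Printf("DoQuery call for connection %d finished\n", c.id)
    
        // Simulate the query latency with a timer that can be abandoned,
        // instead of an uninterruptible time.Sleep.
        t := time.NewTimer(c.delay)
        defer t.Stop()
        select {
        case <-t.C:
            // Return the connection id so the output shows which connection won.
            return Result(c.id), nil
    
        case <-ctx.Done():
            // Canceled (or deadline exceeded) before the query finished.
            return Result(0), ctx.Err()
        }
    }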
    

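    That would be the case for a real network call or database query: for example, net/http requests can carry a context via http.NewRequestWithContext, and database/sql offers QueryContext.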

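    Then, Query must not leak goroutines: it cancels the remaining calls as soon as the first one returns, and waits for all of them before returning: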

    func Query(ctx context.Context, conns []Conn, query string) (Result, error) {
        ctx, cancel := context.WithCancel(ctx)
        defer cancel()
    
        var first atomic.Bool
        var result Result
        var err error
    
        setFirst := func(r Result, e error) {
            if first.CompareAndSwap(false, true) {
                result, err = r, e
                cancel()
            }
        }
    
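        var wg sync.WaitGroup
        for _, conn := range conns {
            wg.Add(1)
            // Since Go 1.22 each iteration has its own conn, so capturing it
            // below is safe; on older versions, pass conn as an argument.
            go func() {
                defer wg.Done()
                defer fmt.Printf("goroutine for connection %d finished\n", conn.id)
    
                // Every goroutine reports its outcome; only the first is kept.
                setFirst(conn.DoQuery(ctx, query))
            }()
        }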
    
        wg.Wait()
    
        return result, err
    }
    

    I'm using an atomic.Bool (available since Go 1.19) here for synchronization, since you are only interested in the first result. Feel free to use a mutex or whatever fits your problem.

    You would call this with a context too, derived from the root context:

    func main() {
        ctx := context.Background()
        conns := []Conn{
            {id: 1, delay: 1 * time.Second},
            {id: 2, delay: 3 * time.Second},
            {id: 3, delay: 5 * time.Second},
            {id: 4, delay: 4 * time.Second},
        }
    
        start := time.Now()
        result, err := Query(ctx, conns, "select count(*) from users;")
        duration := time.Since(start)
    
        fmt.Println(result, err, duration)
    }
    

    A blog post explains this principle in detail.