Search code examples
Tags: multithreading, go, pipeline, goroutine

How to fix goroutine leaks when cancelling a pipeline


I build the pipeline from thread-pool functions and pass a context.Context to them as an argument. When the cancel() function is called or the timeout expires, the pipeline must terminate gracefully so that no worker goroutines are left running.

Functions I work with:

// generate starts a goroutine that sends 0, 1, ..., amount-1 on the returned
// unbuffered channel and closes the channel afterwards. The goroutine has no
// way to be stopped early: if the reader gives up before the last value, it
// stays blocked on the send forever.
func generate(amount int) <-chan int {
    values := make(chan int)
    produce := func() {
        defer close(values)
        for next := 0; next < amount; next++ {
            values <- next
        }
    }
    go produce()
    return values
}

// sum reads from input until the channel is closed and returns the total of
// every value received. It blocks until the sender closes input.
func sum(input <-chan int) int {
    var total int
    for {
        v, ok := <-input
        if !ok {
            return total
        }
        total += v
    }
}
// process fans values from input out to a pool of `workers` goroutines,
// applies do to each one and merges the results onto the returned channel.
// The channel is closed once every worker has stopped.
//
// A worker stops when input is closed or ctx is done. do itself cannot be
// interrupted: once a value has been received, do runs to completion, and
// only the following send observes cancellation. A worker that stops on
// cancellation does not drain input, so an upstream sender that does not
// watch ctx may stay blocked. With workers <= 0 no worker is started, the
// returned channel is closed right away and input is never read.
func process[T any, R any](ctx context.Context, workers int, input <-chan T, do func(T) R) <-chan R {
    out := make(chan R)
    var running sync.WaitGroup

    worker := func() {
        defer running.Done()
        for {
            var item T
            select {
            case v, more := <-input:
                if !more {
                    return
                }
                item = v
            case <-ctx.Done():
                return
            }

            // Equivalent to using do(item) as the send operand of the select
            // below: send operands are evaluated before a case is chosen.
            produced := do(item)

            select {
            case out <- produced:
            case <-ctx.Done():
                return
            }
        }
    }

    for n := workers; n > 0; n-- {
        running.Add(1)
        go worker()
    }

    go func() {
        running.Wait()
        close(out)
    }()

    return out
}

Usage:

// main wires generate -> double (one second per value) -> add ten into a
// pipeline bounded by a 1.2s timeout. It then prints the sum of the values
// that came out and the number of goroutines still alive afterwards.
func main() {
    const (
        deadline = 1200 * time.Millisecond
        poolSize = 15
    )

    ctx, cancel := context.WithTimeout(context.Background(), deadline)
    defer cancel()

    double := func(val int) int {
        time.Sleep(time.Second)
        return val * 2
    }
    addTen := func(val int) int {
        return val + 10
    }

    source := generate(1000)
    doubled := process(ctx, poolSize, source, double)
    final := process(ctx, poolSize, doubled, addTen)

    total := sum(final)
    fmt.Println("Result: ", total)                        //  360 is ok
    fmt.Println("Num goroutine: ", runtime.NumGoroutine()) // 18 is too much
}

I understand that this happens because all the increase goroutines have ended while the multiply goroutines are still running.

Is there any canonical way to solve this problem?


Solution

  • You are expecting something like structured concurrency, where all goroutines end when the current scope ends, but your code is not designed to meet that expectation. The generate goroutine leaks whenever its channel is not fully drained, and your do functions cannot be cancelled.

    Adding cancellation support to generate and to your do functions helps a little:

    package main
    
    import (
        "context"
        "fmt"
        "runtime"
        "sync"
        "time"
    )
    
    // main runs generate -> multiply -> increase with a 1.2s timeout. The
    // multiply step takes one second per value but gives up as soon as ctx is
    // cancelled, so typically only the first round of values reaches sum.
    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 1200*time.Millisecond)
        defer cancel()
    
        input := generate(ctx, 1_000)
    
        multiplied := process(ctx, 15, input, func(ctx context.Context, val int) (int, error) {
            select {
            case <-ctx.Done():
                return 0, ctx.Err()
    
            case <-time.After(time.Second):
                return val * 2, nil
            }
        })
    
        increased := process(ctx, 15, multiplied, func(_ context.Context, val int) (int, error) {
            return val + 10, nil
        })
    
        // Typically 360: each of the 15 workers finishes one value (0..14)
        // before the timeout, giving (0+1+...+14)*2 + 15*10.
        fmt.Println("Result: ", sum(increased))
        // sum returns once the last stage's output is closed. The generator
        // and the multiply workers exit on their own after cancellation, and
        // process offers no way to wait for them, so this count can still be
        // above 1.
        fmt.Println("Num goroutine: ", runtime.NumGoroutine())
    }
    
    // generate sends 0, 1, ..., amount-1 on the returned unbuffered channel
    // from a background goroutine and closes the channel when it is done.
    // Every send also watches ctx: once ctx is cancelled, the goroutine stops
    // producing, closes the channel and exits instead of blocking on a reader
    // that has gone away.
    func generate(ctx context.Context, amount int) <-chan int {
        out := make(chan int)
        go func() {
            defer close(out)
            next := 0
            for next < amount {
                select {
                case out <- next:
                    next++
                case <-ctx.Done():
                    return
                }
            }
        }()
        return out
    }
    
    // sum reads from input until the channel is closed and returns the total
    // of every value received. It blocks until the sender closes input.
    func sum(input <-chan int) int {
        var total int
        for {
            v, ok := <-input
            if !ok {
                return total
            }
            total += v
        }
    }
    
    // process starts `workers` goroutines that receive values from input,
    // transform each one with do, and send the results on the returned
    // channel. The channel is closed once every worker has returned.
    //
    // A worker returns when input is closed, when ctx is cancelled, or when do
    // returns a non-nil error. In the error case the value is dropped and the
    // error is not reported. Both blocking operations select on ctx.Done():
    // receiving from input and sending to the result channel. This means
    // cancelling ctx never leaves a worker parked on a send that the
    // downstream stage will no longer receive. do gets ctx so that slow work
    // can be interrupted as well.
    //
    // With workers <= 0 no worker is started, the returned channel is closed
    // right away and input is never read.
    func process[T any, R any](ctx context.Context, workers int, input <-chan T, do func(context.Context, T) (R, error)) <-chan R {
        wg := new(sync.WaitGroup)
        result := make(chan R)
    
        for i := 0; i < workers; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                for {
                    select {
                    case <-ctx.Done():
                        return
                    case val, ok := <-input:
                        if !ok {
                            return
                        }
                        r, err := do(ctx, val)
                        if err != nil {
                            // Best effort: the failing value is dropped and
                            // this worker stops.
                            return
                        }
                        // Guard the send with ctx. After cancellation the
                        // downstream workers may have stopped receiving, and
                        // an unguarded send would block this worker forever.
                        select {
                        case <-ctx.Done():
                            return
                        case result <- r:
                        }
                    }
                }
            }()
        }
    
        // Close result only after every worker has returned. This ends
        // consumers ranging over it, and no send can reach a closed channel.
        go func() {
            defer close(result)
            wg.Wait()
        }()
    
        return result
    }
    

    More on this is covered in “Advanced Go Concurrency Patterns”. As a general recommendation, when you aim for structured concurrency, write synchronous code first and only later work on running it concurrently.