
Deferred call to sync.WaitGroup.Wait() in a goroutine: why does this work?


I'm trying to understand the Attack() function (https://github.com/tsenart/vegeta/blob/44a49c878dd6f28f04b9b5ce5751490b0dce1e18/lib/attack.go#L253-L312) in the source code of the vegeta load testing tool/library. I've created a simplified example:

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go attack(&wg)
    }
    // wg.Wait()

    go func() {
        defer wg.Wait()
    }()
}

func attack(wg *sync.WaitGroup) {
    defer wg.Done()
    time.Sleep(1 * time.Second)
    fmt.Println("foobar")
}

What I notice is that this program returns immediately without printing foobar 10 times. Only if I uncomment the wg.Wait() line do I see foobar printed 10 times after 1 second. This makes sense to me, because main() returns before wg.Wait() is ever called.

What I don't understand, then, is how the Attack() method works in vegeta, because it seems to follow a similar pattern:

func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <-chan *Result {
    var wg sync.WaitGroup

    workers := a.workers
    if workers > a.maxWorkers {
        workers = a.maxWorkers
    }

    results := make(chan *Result)
    ticks := make(chan struct{})
    for i := uint64(0); i < workers; i++ {
        wg.Add(1)
        go a.attack(tr, name, &wg, ticks, results)
    }

    go func() {
        defer close(results)
        defer wg.Wait()
        defer close(ticks)

        began, count := time.Now(), uint64(0)
        for {
            elapsed := time.Since(began)
            if du > 0 && elapsed > du {
                return
            }

            wait, stop := p.Pace(elapsed, count)
            if stop {
                return
            }

            time.Sleep(wait)

            if workers < a.maxWorkers {
                select {
                case ticks <- struct{}{}:
                    count++
                    continue
                case <-a.stopch:
                    return
                default:
                    // all workers are blocked. start one more and try again
                    workers++
                    wg.Add(1)
                    go a.attack(tr, name, &wg, ticks, results)
                }
            }

            select {
            case ticks <- struct{}{}:
                count++
            case <-a.stopch:
                return
            }
        }
    }()

    return results
}

where the attack() method reads

func (a *Attacker) attack(tr Targeter, name string, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) {
    defer workers.Done()
    for range ticks {
        results <- a.hit(tr, name)
    }
}

Why doesn't Attack() return immediately, before any attack() worker does its job, given that its wg.Wait() is also called inside a goroutine?


Solution

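  • vegeta's Attack also returns immediately, but it returns a channel that is populated by the worker goroutines that keep running in the background. Once those goroutines finish, the channel is closed (defer close(results)), which lets the code that receives from the results channel detect completion. The difference from your simplified example is that the caller of Attack keeps receiving from that channel (for example with a for ... range loop), and that receive blocks until the channel is closed. In your example nothing blocks main(), and when main() returns the program exits without waiting for any other goroutine.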

    Example:

    package main
    
    import (
        "fmt"
        "sync"
        "time"
    )
    
    func main() {
        results := attacks()
    
        fmt.Println("attacks returned")
    
        for result := range results {
            fmt.Println(result)
        }
    }
    
    func attacks() <-chan string {
        // A channel to hold the results
        c := make(chan string)
    
        // Fire 10 routines populating the channel
        var wg sync.WaitGroup
        for i := 0; i < 10; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                attack(c)
            }()
        }
    
        // Close channel once routines are finished
        go func() {
            wg.Wait()
            close(c)
        }()
    
        // Return immediately; the caller ranges over c until it is closed
        return c
    }
    
    func attack(c chan<- string) {
        time.Sleep(1 * time.Second)
        c <- "foobar"
    }
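The sketch below mirrors the shape of vegeta's Attack more closely, including the order of its three deferred calls. It is a simplification under stated assumptions. attackSketch is a hypothetical name, the int results stand in for *Result, and a fixed hits count replaces the Pacer and duration checks. Adaptive worker spawning and stopch are omitted.

```go
package main

import (
	"fmt"
	"sync"
)

// attackSketch (hypothetical) returns immediately, while a background
// goroutine feeds ticks to workers and closes results once they are done.
// hits is a stand-in for vegeta's Pacer and duration logic.
func attackSketch(workers, hits int) <-chan int {
	var wg sync.WaitGroup
	results := make(chan int)
	ticks := make(chan struct{})

	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// Each worker keeps receiving ticks until ticks is closed.
			for range ticks {
				results <- id
			}
		}(i)
	}

	go func() {
		// Deferred calls run in reverse (LIFO) order:
		defer close(results) // 3rd: tell consumers no more results are coming
		defer wg.Wait()      // 2nd: wait for every worker to exit its loop
		defer close(ticks)   // 1st: end the workers' range loops

		for n := 0; n < hits; n++ {
			ticks <- struct{}{}
		}
	}()

	return results // returned before any tick has been sent
}

func main() {
	count := 0
	for range attackSketch(4, 10) { // blocks until results is closed
		count++
	}
	fmt.Println("results:", count)
}
```

When the producer loop ends, close(ticks) runs first and ends each worker's for range ticks loop. wg.Wait() then blocks until every worker has called Done(), and only after that is results closed. The range loop in main() is what keeps the program alive until then.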