Tags: go, concurrency, waitgroup

Check if all goroutines have finished without using wg.Wait()


Let's say I have a function IsAPrimaryColour() that works by calling three other functions: IsRed(), IsGreen() and IsBlue(). Since the three functions are independent of one another, they can run concurrently. The return conditions are:

  1. If any of the three functions returns true, IsAPrimaryColour() should also return true. There is no need to wait for the other functions to finish. That is: IsAPrimaryColour() is true if IsRed() is true OR IsGreen() is true OR IsBlue() is true
  2. If all functions return false, IsAPrimaryColour() should also return false. That is: IsAPrimaryColour() is false if IsRed() is false AND IsGreen() is false AND IsBlue() is false
  3. If any of the three functions returns an error, IsAPrimaryColour() should also return the error. There is no need to wait for the other functions to finish, or to collect any other errors.

The thing I'm struggling with is how to exit the function as soon as any of the three functions returns true, while still waiting for all three to finish if they all return false. If I use a sync.WaitGroup, I have to wait for all 3 goroutines to finish before I can return from the calling function.

Therefore, I'm using a loop counter to keep track of how many messages I have received on the channel, and I exit the function once I have received all 3.

https://play.golang.org/p/kNfqWVq4Wix

package main

import (
    "errors"
    "fmt"
    "time"
)

func main() {
    x := "something"
    result, err := IsAPrimaryColour(x)

    if err != nil {
        fmt.Printf("Error: %v\n", err)
    } else {
        fmt.Printf("Result: %v\n", result)
    }
}

func IsAPrimaryColour(value interface{}) (bool, error) {
    // Both channels are buffered for all three results, so a goroutine never
    // blocks on its send, even after this function has already returned.
    // They are deliberately not closed: a deferred close would make any
    // goroutine that finishes after an early return panic with
    // "send on closed channel".
    found := make(chan bool, 3)
    errs := make(chan error, 3) // named errs so it does not shadow the errors package
    start := time.Now()

    //call the first function, return the result on the 'found' channel and any errors on the 'errs' channel
    go func() {
        result, err := IsRed(value)
        if err != nil {
            errs <- err
        } else {
            found <- result
        }
        fmt.Printf("IsRed done in %v\n", time.Since(start))
    }()

    //call the second function, return the result on the 'found' channel and any errors on the 'errs' channel
    go func() {
        result, err := IsGreen(value)
        if err != nil {
            errs <- err
        } else {
            found <- result
        }
        fmt.Printf("IsGreen done in %v\n", time.Since(start))
    }()

    //call the third function, return the result on the 'found' channel and any errors on the 'errs' channel
    go func() {
        result, err := IsBlue(value)
        if err != nil {
            errs <- err
        } else {
            found <- result
        }
        fmt.Printf("IsBlue done in %v\n", time.Since(start))
    }()

    //loop counter which will be incremented every time we read a value from the 'found' channel
    var counter int

    for {
        // No default case: select blocks until some goroutine sends,
        // instead of spinning the CPU in a busy loop.
        select {
        case result := <-found:
            counter++
            fmt.Printf("received a value on the results channel after %v. Value of counter is %d\n", time.Since(start), counter)
            if result {
                fmt.Printf("some goroutine returned true\n")
                return true, nil
            }
        case err := <-errs:
            // only non-nil errors are ever sent on errs
            fmt.Printf("some goroutine returned an error\n")
            return false, err
        }

        //check if we have received all 3 messages on the 'found' channel. If so, all 3 functions must have returned false and we can thus return false also
        if counter == 3 {
            fmt.Printf("all goroutines have finished and none of them returned true\n")
            return false, nil
        }
    }
}

func IsRed(value interface{}) (bool, error) {
    return false, nil
}

func IsGreen(value interface{}) (bool, error) {
    time.Sleep(time.Millisecond * 100) //change this to a value greater than 200 to make this function take longer than IsBlue()
    return true, nil
}

func IsBlue(value interface{}) (bool, error) {
    time.Sleep(time.Millisecond * 200)
    return false, errors.New("something went wrong")
}

Although this works well enough, I wonder whether I'm overlooking some language feature that would do this in a better way.


Solution

  • You don't have to block the main goroutine on wg.Wait(); you can block a separate goroutine instead, for example:

    doneCh := make(chan struct{})
    
    go func() {
        wg.Wait()
        close(doneCh)
    }()
    

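    Then you can wait on doneCh in your select to see whether all the goroutines have finished. Because select chooses at random when several cases are ready, doneCh may fire while a true or an error is still buffered in the 'found' or 'errors' channel, so drain those channels before concluding that every function returned false.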