I was watching Rob Pike's fascinating talk "Concurrency is not Parallelism" and tried to extend the example with concurrent queries, but it seems that I don't fully understand how goroutines work. In particular, I don't understand how to cancel a goroutine during its execution.
Here is a dummy connection that can perform a query (suppose it is unchangeable, like an external API):
type Conn struct {
	id    int
	delay time.Duration
}

type Result int

func (c *Conn) DoQuery(query string) Result {
	defer fmt.Printf("DoQuery call for connection %d finished\n", c.id)
	time.Sleep(c.delay)
	return Result(len(query))
}
Below is the Query function that sends a query to each connection. When it gets the first result, it is supposed to cancel all the other calls (as a teardown), but it does not. Once a goroutine starts the query, it loses the ability to watch the ctx.Done() channel, because the conn.DoQuery(...) call is blocking.
func Query(conns []Conn, query string) Result {
	results := make(chan Result, len(conns))
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	for _, conn := range conns {
		go func() {
			defer fmt.Printf("goroutine for connection %d finished\n", conn.id)
			select {
			case <-ctx.Done():
				return
			default:
				// While the goroutine computes the result of this call
				// it does not watch the ctx.Done() channel, so the call
				// cannot be canceled mid-execution.
				results <- conn.DoQuery(query)
			}
		}()
	}
	return <-results
}
Here is also the main function that ties everything together (in case you want to run the example):
func main() {
	conns := []Conn{
		{id: 1, delay: 1 * time.Second},
		{id: 2, delay: 3 * time.Second},
		{id: 3, delay: 5 * time.Second},
		{id: 4, delay: 4 * time.Second},
	}
	start := time.Now()
	result := Query(conns, "select count(*) from users;")
	duration := time.Since(start)
	fmt.Println(result, duration)
	for {}
}
So my question is: what is the correct way to implement the Query function to achieve the desired behavior?
Your DoQuery method has to support cancellation (Go Playground):
func (c *Conn) DoQuery(ctx context.Context, query string) (Result, error) {
	defer fmt.Printf("DoQuery call for connection %d finished\n", c.id)
	t := time.NewTimer(c.delay)
	defer t.Stop()
	select {
	case <-t.C:
		return Result(c.id), nil
	case <-ctx.Done():
		return Result(0), ctx.Err()
	}
}
That would be the case for a network call or a database query.
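To make the network-call case concrete, here is a minimal sketch using net/http, where the request carries the context and the client gives up as soon as the context expires. The helper names slowFetch and requestTimedOut, the throwaway httptest server, and the delays are all made up for this illustration; they are not part of the code in the question.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"time"
)

// slowFetch (hypothetical helper) starts a local test server that answers
// after serverDelay, then requests it with a context that expires after timeout.
func slowFetch(serverDelay, timeout time.Duration) error {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-time.After(serverDelay):
			fmt.Fprintln(w, "done")
		case <-r.Context().Done():
			// The client went away, so the handler stops working too.
		}
	}))
	defer srv.Close()

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// The context travels with the request; Do returns early once it is done.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	return nil
}

// requestTimedOut (hypothetical helper) reports whether a slow request was
// abandoned because its context deadline passed.
func requestTimedOut() bool {
	err := slowFetch(2*time.Second, 100*time.Millisecond)
	return errors.Is(err, context.DeadlineExceeded)
}

func main() {
	fmt.Println("timed out:", requestTimedOut())
}
```

For a database, the analogous calls in database/sql are the ...Context variants such as QueryContext.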
Then, you shouldn't leak goroutines:
func Query(ctx context.Context, conns []Conn, query string) (Result, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	var first atomic.Bool
	var result Result
	var err error
	setFirst := func(r Result, e error) {
		if first.CompareAndSwap(false, true) {
			result, err = r, e
			cancel()
		}
	}
	var wg sync.WaitGroup
	for _, conn := range conns {
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer fmt.Printf("goroutine for connection %d finished\n", conn.id)
			setFirst(conn.DoQuery(ctx, query))
		}()
	}
	wg.Wait()
	return result, err
}
I'm using an atomic.Bool here for synchronization, since you are only interested in the first result. Feel free to use a mutex or whatever fits your problem.
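As one possible alternative, the same "first outcome wins" logic can be sketched with sync.Once. QueryOnce and demoQueryOnce are hypothetical names for this illustration, the delays are shortened so the example runs quickly, and DoQuery mirrors the cancellable version above without the debug output.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type Result int

type Conn struct {
	id    int
	delay time.Duration
}

// DoQuery mirrors the cancellable version: it waits for the delay or for
// cancellation, whichever comes first.
func (c *Conn) DoQuery(ctx context.Context, query string) (Result, error) {
	t := time.NewTimer(c.delay)
	defer t.Stop()
	select {
	case <-t.C:
		return Result(c.id), nil
	case <-ctx.Done():
		return 0, ctx.Err()
	}
}

// QueryOnce (hypothetical variant of Query) records the first outcome with
// sync.Once instead of atomic.Bool.
func QueryOnce(ctx context.Context, conns []Conn, query string) (Result, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		once   sync.Once
		result Result
		err    error
		wg     sync.WaitGroup
	)
	for i := range conns {
		conn := &conns[i] // explicit per-iteration variable, independent of Go version
		wg.Add(1)
		go func() {
			defer wg.Done()
			r, e := conn.DoQuery(ctx, query)
			once.Do(func() {
				result, err = r, e
				cancel() // tell the remaining queries to stop
			})
		}()
	}
	wg.Wait() // every goroutine has exited before we return
	return result, err
}

// demoQueryOnce (hypothetical helper) runs QueryOnce against three fake
// connections and reports the result and the elapsed time.
func demoQueryOnce() (Result, time.Duration, error) {
	conns := []Conn{
		{id: 1, delay: 10 * time.Millisecond},
		{id: 2, delay: 300 * time.Millisecond},
		{id: 3, delay: 500 * time.Millisecond},
	}
	start := time.Now()
	r, err := QueryOnce(context.Background(), conns, "select count(*) from users;")
	return r, time.Since(start), err
}

func main() {
	r, elapsed, err := demoQueryOnce()
	fmt.Println(r, err, elapsed.Round(10*time.Millisecond))
}
```

As in the atomic.Bool version, the first outcome wins even if it is an error, and reading result and err after wg.Wait() is safe because every goroutine has finished by then.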
You would call this with a context too, derived from the root context:
func main() {
	ctx := context.Background()
	conns := []Conn{
		{id: 1, delay: 1 * time.Second},
		{id: 2, delay: 3 * time.Second},
		{id: 3, delay: 5 * time.Second},
		{id: 4, delay: 4 * time.Second},
	}
	start := time.Now()
	result, err := Query(ctx, conns, "select count(*) from users;")
	duration := time.Since(start)
	fmt.Println(result, err, duration)
}
A blog post explains this principle in detail.
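If DoQuery really cannot be changed, as the question supposes, the blocking call itself cannot be interrupted: Go has no way to stop a goroutine from the outside. What you can do is stop waiting for it. The sketch below assumes the original blocking DoQuery; doQueryCtx and abandonSlowCall are hypothetical helpers, and the delays are shortened for illustration.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

type Result int

// Conn and DoQuery mirror the unchangeable API from the question.
type Conn struct {
	id    int
	delay time.Duration
}

func (c *Conn) DoQuery(query string) Result {
	time.Sleep(c.delay)
	return Result(len(query))
}

// doQueryCtx (hypothetical wrapper) returns as soon as ctx is done, but the
// underlying DoQuery call keeps running until it finishes by itself.
func doQueryCtx(ctx context.Context, c *Conn, query string) (Result, error) {
	// Buffered, so the abandoned goroutine can still send and then exit.
	done := make(chan Result, 1)
	go func() { done <- c.DoQuery(query) }()

	select {
	case r := <-done:
		return r, nil
	case <-ctx.Done():
		return 0, ctx.Err()
	}
}

// abandonSlowCall (hypothetical helper) gives a 500ms query only 50ms to finish.
func abandonSlowCall() (time.Duration, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	start := time.Now()
	_, err := doQueryCtx(ctx, &Conn{id: 1, delay: 500 * time.Millisecond}, "select 1;")
	return time.Since(start), err
}

func main() {
	elapsed, err := abandonSlowCall()
	fmt.Println(err, elapsed.Round(10*time.Millisecond))
}
```

The caller gets control back early, but the work is only abandoned, not cancelled: the inner goroutine and whatever resources DoQuery holds stay alive until the call returns. That is why a context-aware API is preferable whenever one is available.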