Search code examples
Tags: go, goroutine

Unable to use goroutines concurrently to find max until context is cancelled


I have a working synchronous solution, without goroutines, in which findMax returns the maximum of several compute calls.

package main

import (
    "context"
    "fmt"
    "math/rand"
    "time"
)

// findMax calls compute up to concurrency times, one call after another, and
// returns the largest value seen. ctx is checked before every call: once it
// is cancelled no further calls are made and the maximum observed so far is
// returned (0 if no call completed). A compute call that is already running
// is not interrupted, because compute does not take a context.
func findMax(ctx context.Context, concurrency int) uint64 {
    var best uint64
    for i := 0; i < concurrency; i++ {
        if ctx.Err() != nil {
            // Deadline exceeded or caller cancelled: stop computing.
            break
        }
        if num := compute(); num > best {
            best = num
        }
    }
    return best
}

// compute mocks a slow, blocking operation: it sleeps for a random
// 0-99 ms and then returns a random uint64.
func compute() uint64 {
    pause := rand.Int63n(100)
    time.Sleep(time.Duration(pause) * time.Millisecond)
    return rand.Uint64()
}

// main bounds the search with a 2-second timeout and prints the value
// findMax returns when given 10 as its concurrency argument.
func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    fmt.Println(findMax(ctx, 10))
}


https://play.golang.org/p/lYXRNTDtNCI

When I try to make findMax use goroutines to call compute repeatedly, with as many goroutines as possible, until the context ctx is cancelled by the caller (main), I get 0 every time instead of the expected maximum of the goroutines' compute calls. I have tried different approaches, and most of them end in a deadlock.

package main

import (
    "context"
    "fmt"
    "math/rand"
    "time"
)

// findMax starts one goroutine per loop iteration (concurrency in total);
// each goroutine calls compute and tries to update max. This version is
// broken:
//   - max and num are shared by every goroutine and written without any
//     synchronization, which is a data race;
//   - nothing waits for the goroutines, so the function usually returns
//     before any of them has stored a result (hence the 0 output);
//   - ctx is only checked while goroutines are being started, a loop that
//     finishes almost immediately, so cancellation has practically no effect.
func findMax(ctx context.Context, concurrency int) uint64 {
    var (
        max uint64 = 0
        num uint64 = 0
    )

    for i := 0; i < concurrency; i++ {
        select {
        case <- ctx.Done():
            return max
        default:
            go func() {
                // BUG: writes the shared num; it should be a local (num := compute()).
                num = compute()
                // BUG: unsynchronized read-modify-write of the shared max.
                if num > max {
                    max = num
                }
            }()
        }
    }

    // Returns without waiting for the goroutines started above.
    return max
}

// compute is a MOCK of a blocking operation. It waits a random number of
// milliseconds in [0, 100) before producing a random uint64.
func compute() uint64 {
    wait := time.Millisecond * time.Duration(rand.Int63n(100))
    time.Sleep(wait)
    return rand.Uint64()
}

// main runs findMax under a context that expires after two seconds and
// prints whatever maximum it returns.
func main() {
    const (
        timeout = 2 * time.Second
        workers = 10
    )

    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()

    largest := findMax(ctx, workers)
    fmt.Println(largest)
}

https://play.golang.org/p/3fFFq2xlXAE


Solution

  • Your program has multiple problems:

    1. You are spawning multiple goroutines that operate on shared variables, namely max and num, which leads to a data race because they are not protected (e.g. by a sync.Mutex).
    2. Here num is modified by every worker goroutine, but it should be local to each worker; otherwise computed data can be lost (e.g. one worker goroutine computes a result and stores it in num, but right after that a second worker computes its own result and overwrites num).
     num = compute() // should be: num := compute()
    
    3. You are not waiting for every goroutine to finish its computation, so the result may be incorrect: not every worker's computation is taken into account even though the context wasn't cancelled. Use sync.WaitGroup or channels to fix this.

    Here's a sample program that addresses most of the issues in your code:

    package main
    
    import (
        "context"
        "fmt"
        "math/rand"
        "sync"
        "time"
    )
    
    // result holds the running maximum shared between worker goroutines.
    // The embedded RWMutex guards max: take Lock to update it and RLock to
    // read it while workers may still be running.
    type result struct {
        max uint64
        sync.RWMutex
    }
    
    func findMax(ctx context.Context, workers int) uint64 {
        var (
            res = result{}
            wg  = sync.WaitGroup{}
        )
    
        for i := 0; i < workers; i++ {
            select {
            case <-ctx.Done():
                // RLock to read res.max
                res.RLock()
                ret := res.max
                res.RUnlock()
                return ret
            default:
                wg.Add(1)
                go func() {
                    defer wg.Done()
                    num := compute()
    
                    // Lock so that read from res.max and write
                    // to res.max is safe. Else, data race could
                    // occur.
                    res.Lock()
                    if num > res.max {
                        res.max = num
                    }
                    res.Unlock()
                }()
            }
        }
    
        // Wait for all the goroutine to finish work i.e., all
        // workers are done computing and updating the max.
        wg.Wait()
    
        return res.max
    }
    
    // compute simulates slow blocking work: it sleeps for a random 0-99 ms
    // and then returns a random uint64.
    func compute() uint64 {
        d := time.Duration(rand.Int63n(100)) * time.Millisecond
        time.Sleep(d)
        return rand.Uint64()
    }
    
    // main gives findMax ten workers and a two-second time limit, then prints
    // the maximum it reports.
    func main() {
        limit := 2 * time.Second
        workers := 10

        ctx, cancel := context.WithTimeout(context.Background(), limit)
        defer cancel()

        answer := findMax(ctx, workers)
        fmt.Println(answer)
    }
    

    As @Brits pointed out in the comments, when the context is cancelled you should also make the worker goroutines stop processing (if possible), because their results are no longer needed.