Search code examples
Tags: go, worker, rate-limiting

Worker pool with buffered jobs and fixed polling interval


I have a worker pool listening on a jobs channel, and responding on a results channel.

The jobs producer must run on a fixed ticker interval. Results must be flushed before reading just enough new jobs to fill up the buffer. It's critical to flush results, and read new jobs, in batches.

See example code below, run it on the playground here.

Is it possible to rewrite this without an atomic counter for keeping track of inflight jobs?

// Worker pool with buffered jobs and fixed polling interval

package main

import (
    "fmt"
    "math/rand"
    "os"
    "os/signal"
    "strings"
    "sync"
    "sync/atomic"
    "syscall"
    "time"
)

// main demonstrates a worker pool fed from a buffered jobs channel by a
// producer goroutine that runs on a fixed ticker interval. An atomic
// counter (inflight) tracks how many jobs currently sit in the jobs
// buffer so the producer knows how many new jobs to enqueue each tick.
// Results are flushed in batches at the start of each tick, and shutdown
// is triggered by SIGINT/SIGTERM.
func main() {
    rand.Seed(time.Now().UnixNano())

    // buf is the size of the jobs buffer
    buf := 5

    // workers is the number of workers to start
    workers := 3

    // jobs chan for workers
    jobs := make(chan int, buf)
    // results chan for workers
    // NOTE(review): during shutdown the select loop stops draining results
    // while wg.Wait() runs, so this buffer must absorb every job still in
    // flight (up to buf queued + workers processing) plus any results already
    // buffered. buf*2 leaves only buf-workers slack — confirm that margin
    // is sufficient to rule out a worker blocking on `results <-` below.
    results := make(chan int, buf*2)

    // jobID is incremented for each job sent on the jobs chan
    var jobID int

    // inflight is a count of the items in the jobs chan buffer
    var inflight uint64

    // pollInterval for jobs producer
    pollInterval := 500 * time.Millisecond

    // pollDone chan to stop polling
    pollDone := make(chan bool)

    // jobMultiplier on pollInterval for random job processing times
    jobMultiplier := 5

    // done chan to exit program
    done := make(chan bool)

    // Start workers
    wg := sync.WaitGroup{}
    for n := 0; n < workers; n++ {
        wg.Add(1)
        go (func(n int) {
            defer wg.Done()
            for {
                // Receive from channel or block; more is false once jobs
                // is closed and drained, which is the worker's exit signal.
                jobID, more := <-jobs
                if more {
                    // Decrement inflight as soon as the job leaves the buffer.
                    // To subtract a signed positive constant value...
                    // https://golang.org/pkg/sync/atomic/#AddUint64
                    c := atomic.AddUint64(&inflight, ^uint64(0))
                    fmt.Println(
                        fmt.Sprintf("worker %v processing %v - %v jobs left",
                            n, jobID, c))
                    // Processing the job... (simulated by a random sleep of
                    // 0..jobMultiplier-1 poll intervals)
                    m := rand.Intn(jobMultiplier)
                    time.Sleep(time.Duration(m) * pollInterval)
                    results <- jobID
                } else {
                    fmt.Println(fmt.Sprintf("worker %v exited", n))
                    return
                }
            }
        })(n)
    }

    // Signal to exit
    sig := make(chan os.Signal, 1)
    signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
    fmt.Println("ctrl+c to exit")

    // Producer/collector goroutine: on each tick it flushes the batched
    // results, then tops the jobs buffer back up to buf using the atomic
    // inflight count; between ticks it accumulates results one at a time.
    go (func() {
        ticker := time.NewTicker(pollInterval)
        r := make([]string, 0)
        flushResults := func() {
            fmt.Println(
                fmt.Sprintf("===> results: %v", strings.Join(r, ",")))
            r = make([]string, 0)
        }

        for {
            select {
            case <-ticker.C:
                flushResults()

                // Fetch jobs: enqueue exactly enough new jobs to refill
                // the buffer. inflight <= buf, so d never underflows.
                c := atomic.LoadUint64(&inflight)
                d := uint64(buf) - c
                for i := 0; i < int(d); i++ {
                    jobID++
                    jobs <- jobID
                    atomic.AddUint64(&inflight, 1)
                }
                fmt.Println(fmt.Sprintf("===> send %v jobs", d))

            case jobID := <-results:
                // Batch the result for the next flush (jobID here shadows
                // the outer jobID; it is the completed job's ID).
                r = append(r, fmt.Sprintf("%v", jobID))

            case <-pollDone:
                // Stop polling for new jobs
                ticker.Stop()

                // Close jobs channel to stop workers
                close(jobs)

                // Wait for workers to exit
                wg.Wait()
                close(results)

                // Flush remaining results (safe to drain to exhaustion
                // because results is closed and no workers remain)
                for {
                    jobID, more := <-results
                    if more {
                        r = append(r, fmt.Sprintf("%v", jobID))
                    } else {
                        break
                    }
                }
                flushResults()

                // Done!
                done <- true
                return
            }
        }
    })()

    // Wait for exit signal
    <-sig

    fmt.Println("---------| EXIT |---------")
    pollDone <- true
    <-done
    fmt.Println("...done")
}


Solution

  • Here is a channel-based version of your code, functionally equivalent to the intent of the example above. The key point is that we're not using any atomic values to vary the logic of the code, because atomics offer no synchronization between the goroutines. All interactions between the goroutines are synchronized using channels, sync.WaitGroup, or context.Context. There are probably better ways to solve the problem at hand, but this demonstrates that no atomics are necessary to coordinate the queue and workers.

    The only value that is still left uncoordinated between goroutines here is the use of len(jobs) in the log output. Whether it makes sense to use it or not is up to you, as its value is meaningless in a concurrent world, but it's safe because it's synchronized for concurrent use and there is no logic based on the value.

    buf := 5
    workers := 3
    jobs := make(chan int, buf)
    
    // results buffer must always be larger than workers + buf to prevent deadlock
    results := make(chan int, buf*2)
    
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()
    
    // Start workers
    var wg sync.WaitGroup
    for n := 0; n < workers; n++ {
        wg.Add(1)
        go func(n int) {
            defer wg.Done()
            for jobID := range jobs {
                fmt.Printf("worker %v processing %v - %v jobs left\n", n, jobID, len(jobs))
                time.Sleep(time.Duration(rand.Intn(5)) * pollInterval)
                results <- jobID
            }
            fmt.Printf("worker %v exited", n)
        }(n)
    }
    
    var done sync.WaitGroup
    done.Add(1)
    go func() {
        defer done.Done()
        ticker := time.NewTicker(pollInterval)
        r := make([]string, 0)
    
        flushResults := func() {
            fmt.Printf("===> results: %v\n", strings.Join(r, ","))
            r = r[:0]
        }
    
        for {
            select {
            case <-ticker.C:
                flushResults()
    
                // send max buf jobs, or fill the queue
                for i := 0; i < buf; i++ {
                    jobID++
                    select {
                    case jobs <- jobID:
                        continue
                    }
                    break
                }
                fmt.Printf("===> send %v jobs\n", i)
    
            case jobID := <-results:
                r = append(r, fmt.Sprintf("%v", jobID))
    
            case <-ctx.Done():
                // Close jobs channel to stop workers
                close(jobs)
                // Wait for workers to exit
                wg.Wait()
    
                // we can close results for easy iteration because we know
                // there are no more workers.
                close(results)
                // Flush remaining results
                for jobID := range results {
                    r = append(r, fmt.Sprintf("%v", jobID))
                }
                flushResults()
                return
            }
        }
    }()