Search code examples
go retrypolicy

Turn Retry Policy into Reusable Function


We have a simple retry policy for our project:

  1. On the first error, sleep for 1 second.
  2. On the second error, sleep for 5 seconds.
  3. On the third error, sleep for 10 seconds.
  4. On the fourth error, quit retrying and return the error.

Here is what our retry policy looks like:

package main

import (
    "errors"  
    "fmt"         
    "time"    
)

// main keeps calling generateError until it succeeds or the retry policy is
// exhausted: it pauses 1s, 5s and 10s after the first three failures and gives
// up on the fourth. The last result is printed either way.
func main() {
    var err error

    fmt.Println("start!")

    for failures := 0; ; failures++ {
        err = generateError()
        if err == nil {
            fmt.Println("no errors!")
            break
        }

        // Three pauses have already been spent; this failure is final.
        if failures > 2 {
            fmt.Println("giving up...")
            break
        }

        switch failures {
        case 0:
            fmt.Println("sleeping for 1 second...")
            time.Sleep(1 * time.Second)
        case 1:
            fmt.Println("sleeping for 5 seconds...")
            time.Sleep(5 * time.Second)
        case 2:
            fmt.Println("sleeping for 10 seconds...")
            time.Sleep(10 * time.Second)
        }
    }

    fmt.Println("error:", err)
    fmt.Println("done!")
}

// generateError always fails; it stands in for an operation worth retrying.
func generateError() error {
    return errors.New("something happened")
}

Is there a way to turn the above code into a reusable function?


Solution

  • Simply pass generateError as an argument to a retry function that accepts any func() error (I simplified the retry function because I couldn't help myself):

    package main
    
    import (
        "errors"
        "fmt"
        "time"
    )
    
    func main() {
        retry(generateError)
    }
    
    // retry runs f under the fixed backoff policy: after the first, second and
    // third failures it sleeps 1s, 5s and 10s respectively; after a fourth
    // failure it stops and prints the error. Progress is printed to stdout.
    func retry(f func() error) {
        fmt.Println("start!")

        backoff := []time.Duration{
            1 * time.Second,
            5 * time.Second,
            10 * time.Second,
        }

        err := retryWithBackoff(f, backoff, func(d time.Duration) {
            fmt.Printf("sleeping for %v...\n", d)
            time.Sleep(d)
        })
        if err == nil {
            fmt.Println("no errors!")
            return
        }

        fmt.Println("error:", err)
        fmt.Println("done!")
    }

    // retryWithBackoff calls f until it returns nil, invoking wait with the next
    // duration from backoff after each failed attempt. It makes at most
    // len(backoff)+1 attempts and returns nil on success or the error from the
    // final attempt. wait is injected so callers choose how to pause.
    func retryWithBackoff(f func() error, backoff []time.Duration, wait func(time.Duration)) error {
        for _, d := range backoff {
            if err := f(); err == nil {
                return nil
            }
            wait(d)
        }

        // Every pause has been used: make the last attempt and report its
        // outcome directly instead of sleeping again before giving up.
        return f()
    }
    
    // generateError always fails; it stands in for an operation worth retrying.
    func generateError() error {
        return errors.New("something happened")
    }