Search code examples
gowebsocketwaitgroup

Go: negative WaitGroup counter


I'm somewhat new to go and am reworking code that I found somewhere else to fit my needs. Because of that, I don't totally understand what is happening here, although I get the general idea.

I'm running a few websocket clients using go routines, but I'm getting an unexpected error that causes the program to crash. My program seems to close one too many threads (excuse me if this is the wrong terminology) when there is an error reading a message from the websocket (check the conn.ReadMessage() func in the readHandler func). Any ideas on how would I work around this issue? I would really appreciate anyone taking the time to look through it. Thanks in advance!

package main

import (
    "context"
    "fmt"
    "os"
    "time"
    "os/signal"
    "syscall"
    "sync"
    "net/url"
    "github.com/gorilla/websocket"
    "strconv"
    "encoding/json"
    "log"
    "bytes"
    "compress/gzip"
    "io/ioutil"
)

// Structs

type Ping struct {
    Ping    int64   `json:"ping"`
}

type Pong struct {
    Pong        int64       `json:"pong"`
}

type SubParams struct {
    Sub         string          `json:"sub"`
    ID          string          `json:"id"`
}

func InitSub(subType string, pair string, i int) []byte {
    var idInt string = "id" + strconv.Itoa(i)
    subStr := "market." + pair + "." + subType
    sub := &SubParams{
        Sub: subStr,
        ID: idInt,
    }

    out, err := json.MarshalIndent(sub, "", " ")
    if err != nil {
        log.Println(err);
    }
    //log.Println(string(out))
    return out
}

// main func

func main() {
    var server string = "api.huobi.pro"
    pairs := []string{"btcusdt", "ethusdt", "ltcusdt"}
    comms := make(chan os.Signal, 1)
    signal.Notify(comms, os.Interrupt, syscall.SIGTERM)

    ctx := context.Background()
    ctx, cancel := context.WithCancel(ctx)
    var wg sync.WaitGroup

    for x, pair := range pairs {
        wg.Add(1)
        go control(server, "ws", pair, ctx, &wg, x+1)
    }

    <-comms
    cancel()
    wg.Wait()
}

func control(server string, path string, pair string, ctx context.Context, wg *sync.WaitGroup, i int) {
    fmt.Printf("Started control for %s\n", server)
    url := url.URL {
        Scheme: "wss",
        Host: server,
        Path: path,
    }

    fmt.Println(url.String())

    conn, _, err := websocket.DefaultDialer.Dial(url.String(), nil)
    if err != nil {
        panic(err)
    }
    subscribe(conn, pair, i)
    defer conn.Close()

    var localwg sync.WaitGroup

    localwg.Add(1)
    go readHandler(ctx, conn, &localwg, server)

    <- ctx.Done()
    localwg.Wait()
    wg.Done()
    return
}

func readHandler(ctx context.Context, conn *websocket.Conn, wg *sync.WaitGroup, server string) {
    for {

        select {

            case <- ctx.Done():
                wg.Done()
                return
            default:
                _, p, err :=  conn.ReadMessage()
                if err != nil {
                    wg.Done()
                    fmt.Println(err)
                }
                r, err := gzip.NewReader(bytes.NewReader(p))
                if(err == nil) {
                    result, err := ioutil.ReadAll(r)
                    if(err != nil) {
                        fmt.Println(err)
                    }
                    d := string(result)
                    fmt.Println(d)

                    var ping Ping
                    json.Unmarshal([]byte(d), &ping)
                    if (ping.Ping > 0) {
                        str := Pong{Pong: ping.Ping}
                        msg, err := json.Marshal(str)
                        if (err == nil) {
                            fmt.Println(string(msg))
                            conn.WriteMessage(websocket.TextMessage, []byte(msg))
                        }
                    }
            }
        }
    }
}

func subscribe(conn *websocket.Conn, pair string, id int) {
    sub := string(InitSub("trade.detail", pair, id))

    err := conn.WriteMessage(websocket.TextMessage, []byte(sub))
    if err != nil {
        panic(err)
    }
}

Solution

    • Break out of the readHandler loop when the connection fails:

        _, p, err :=  conn.ReadMessage()
        if err != nil {
            wg.Done()
            fmt.Println(err)
            return // <--- add this line
        }
      

      Without the return, the function spins in a tight loop reading errors until the panic.

    • Use defer wg.Done() at the beginning of the goroutine to ensure that Done is called exactly once.

      func readHandler(ctx context.Context, conn *websocket.Conn, wg *sync.WaitGroup, server string) {
          defer wg.Done()
          for {
            select {
            case <-ctx.Done():
                return
            default:
                _, p, err := conn.ReadMessage()
                if err != nil {
                    fmt.Println(err)
                    return
                }
           ...
      

      Update the control function also.

    • Because the caller does not execute any code concurrently with readHander, there's no value in running readHandler is a goroutine. Remove all references to wait groups from readHandler and call the function directly: change go readHandler(ctx, conn, &localwg, server) to readHandler(ctx, conn, server).

    There are more issues, but this should move you further along.