
Processing a Redis queue using BLPOP causes a race condition in unit tests?


I'm trying to implement, in Go, a first-in, first-out task queue as described in Chapter 6.4.1 of the Redis e-book. For testing purposes, I'm passing a CommandExecutor interface to the 'worker' function like so:

package service

import (
    "context"

    "github.com/gomodule/redigo/redis"
    "github.com/pkg/errors"
    "github.com/sirupsen/logrus"
)

const commandsQueue = "queuedCommands:"

var pool = redis.Pool{
    MaxIdle:   50,
    MaxActive: 1000,
    Dial: func() (redis.Conn, error) {
        conn, err := redis.Dial("tcp", ":6379")
        if err != nil {
            logrus.WithError(err).Fatal("initialize Redis pool")
        }
        return conn, err
    },
}

// CommandExecutor executes a command
type CommandExecutor interface {
    Execute(string) error
}

func processQueue(ctx context.Context, done chan<- struct{}, executor CommandExecutor) error {
    rc := pool.Get()
    defer rc.Close()

    for {
        select {
        case <-ctx.Done():
            done <- struct{}{}
            return nil
        default:
            // If the commands queue does not exist, BLPOP blocks until another client
            // performs an LPUSH or RPUSH against it. The timeout argument of zero is
            // used to block indefinitely.
            reply, err := redis.Strings(rc.Do("BLPOP", commandsQueue, 0))
            if err != nil {
                logrus.WithError(err).Errorf("BLPOP %s %d", commandsQueue, 0)
                return errors.Wrapf(err, "BLPOP %s %d", commandsQueue, 0)
            }

            if len(reply) < 2 {
                logrus.Errorf("Expected a reply of length 2, got one of length %d", len(reply))
                return errors.Errorf("Expected a reply of length 2, got one of length %d", len(reply))
            }

            // BLPOP returns a two-element multi-bulk with the first element being the
            // name of the key where an element was popped and the second element
            // being the value of the popped element (cf. https://redis.io/commands/blpop#return-value)
            if err := executor.Execute(reply[1]); err != nil {
                return errors.Wrapf(err, "execute scheduled command: %s", reply[0])
            }
            done <- struct{}{}
        }
    }
}
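The length check and the comment about BLPOP's two-element reply can be pulled into a small helper. This is a minimal sketch: parseBLPOPReply is a hypothetical name, not a redigo function, and the sample reply is made up to match the queue key used above.

```go
package main

import "fmt"

// parseBLPOPReply splits a BLPOP reply into its two parts: the name of the
// list an element was popped from, and the popped value.
// Hypothetical helper for illustration; it is not part of redigo.
func parseBLPOPReply(reply []string) (key, value string, err error) {
	if len(reply) != 2 {
		return "", "", fmt.Errorf("expected a reply of length 2, got one of length %d", len(reply))
	}
	return reply[0], reply[1], nil
}

func main() {
	// A reply shaped like the one redis.Strings(rc.Do("BLPOP", ...)) yields.
	key, value, err := parseBLPOPReply([]string{"queuedCommands:", "foobar"})
	fmt.Println(key, value, err) // queuedCommands: foobar <nil>
}
```

Naming the two parts makes them harder to mix up: the error returned when executor.Execute fails currently wraps reply[0], which is the queue key, not the command that failed.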

I've made a small example repository, https://github.com/kurtpeek/process-queue, with this code as well as an attempt at unit tests. For the unit tests, I have two identical tests that differ only in name:

package service

import (
    "context"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestProcessQueue(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    executor := &CommandExecutorMock{
        ExecuteFunc: func(string) error {
            return nil
        },
    }

    done := make(chan struct{})
    go processQueue(ctx, done, executor)

    rc := pool.Get()
    defer rc.Close()

    _, err := rc.Do("RPUSH", commandsQueue, "foobar")
    require.NoError(t, err)

    <-done

    assert.Exactly(t, 1, len(executor.ExecuteCalls()))
    assert.Exactly(t, "foobar", executor.ExecuteCalls()[0].In1)
}

func TestProcessQueue2(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    executor := &CommandExecutorMock{
        ExecuteFunc: func(string) error {
            return nil
        },
    }

    done := make(chan struct{})
    go processQueue(ctx, done, executor)

    rc := pool.Get()
    defer rc.Close()

    _, err := rc.Do("RPUSH", commandsQueue, "foobar")
    require.NoError(t, err)

    <-done

    assert.Exactly(t, 1, len(executor.ExecuteCalls()))
    assert.Exactly(t, "foobar", executor.ExecuteCalls()[0].In1)
}

where the CommandExecutorMock is generated using moq. If I run each test individually, they pass:

~/g/s/g/k/process-queue> go test ./... -v -run TestProcessQueue2
=== RUN   TestProcessQueue2
--- PASS: TestProcessQueue2 (0.00s)
PASS
ok      github.com/kurtpeek/process-queue/service   0.243s

However, if I run all the tests, the second one times out:

~/g/s/g/k/process-queue> go test ./... -v -timeout 10s
=== RUN   TestProcessQueue
--- PASS: TestProcessQueue (0.00s)
=== RUN   TestProcessQueue2
panic: test timed out after 10s

It seems that when the second test runs, the goroutine started in the first test is still running and BLPOPing the command from the queue, so that the <-done line in the second test blocks indefinitely. This is despite calling cancel() on the parent context of the first test.
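The stealing described above can be reproduced without Redis. This is a hypothetical model, not the repository's code: an unbuffered Go channel stands in for the Redis list, and a plain receive stands in for BLPOP with a timeout of 0. The extra `waiting` channel (also an assumption of the sketch) pins down the otherwise timing-dependent interleaving, so the worker always re-enters its blocking receive before cancel() runs. The names worker and simulate are illustrative.

```go
package main

import (
	"context"
	"fmt"
)

// worker has the same shape as processQueue: it checks ctx only at the top of
// the loop, then blocks on a receive that ignores ctx, standing in for
// BLPOP with a timeout of 0. queue is an in-memory stand-in for the Redis list.
func worker(ctx context.Context, queue <-chan string, waiting chan<- struct{}, done chan<- string) {
	for {
		select {
		case <-ctx.Done():
			return
		default:
			waiting <- struct{}{} // past the ctx check, about to block
			item := <-queue       // cancel() cannot interrupt this receive
			done <- item
		}
	}
}

// simulate plays the first test, then pushes the second test's item and
// reports everything the first worker consumed.
func simulate() []string {
	queue := make(chan string)
	waiting := make(chan struct{}, 1)
	done := make(chan string)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go worker(ctx, queue, waiting, done)

	<-waiting
	queue <- "first" // the first test's RPUSH
	consumed := []string{<-done}

	// While the first test runs its assertions, the worker loops, finds ctx
	// still live and blocks on the queue again.
	<-waiting
	cancel() // the first test's deferred cancel(): too late

	queue <- "second" // the second test's RPUSH, taken by the old worker
	consumed = append(consumed, <-done)
	return consumed
}

func main() {
	fmt.Println(simulate()) // [first second]
}
```

Without the `waiting` channel the outcome depends on scheduling: whether the first test's cancel() lands before or after the worker's next select decides whether the worker exits or blocks again.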

How can I 'isolate' these tests so that they both pass when run together? (I've tried passing the -p 1 flag to go test but to no avail).


Solution

  • This is despite calling cancel() on the parent context of the first test.

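    There is some time between writing to done and calling cancel(), which means that the first test's goroutine can (and does) start a second for/select iteration before the context is cancelled, instead of exiting on <-ctx.Done(). Once it takes the default branch, it blocks in BLPOP with a timeout of 0, where it can no longer observe the cancellation, so it pops the item that the second test pushes. More specifically, the test code includes two assertions before the cancellation: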

        assert.Exactly(t, 1, len(executor.ExecuteCalls()))
        assert.Exactly(t, "foobar", executor.ExecuteCalls()[0].In1)
    

    Only then does the deferred cancel() run, which is too late: by that point the first goroutine is usually already blocked in BLPOP again.

    If you move the cancel() call to just before reading from done, the tests pass:

    func TestProcessQueue(t *testing.T) {
        ctx, cancel := context.WithCancel(context.Background())
    
        executor := &CommandExecutorMock{
            ExecuteFunc: func(string) error {
                return nil
            },
        }
    
        done := make(chan struct{})
        go processQueue(ctx, done, executor)
    
        rc := pool.Get()
        defer rc.Close()
    
        _, err := rc.Do("RPUSH", commandsQueue, "foobar")
        require.NoError(t, err)
    
        cancel() // note this change right here
        <-done
    
        assert.Exactly(t, 1, len(executor.ExecuteCalls()))
        assert.Exactly(t, "foobar", executor.ExecuteCalls()[0].In1)
    }
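The fix above relies on ordering: it works when the goroutine is already inside BLPOP by the time cancel() runs, which the RPUSH round trip makes likely but does not guarantee. A different technique, not the one used in this answer, is to stop blocking indefinitely: BLPOP with a nonzero timeout returns a nil reply when the timeout expires, which gives the loop a chance to re-check the context. The sketch below shows only that loop structure, assuming a hypothetical popFunc abstraction; popFunc, channelPop, pollQueue and demo are illustrative names, and a buffered channel replaces Redis so it runs without a server.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// popFunc abstracts a bounded blocking pop, such as BLPOP with a nonzero
// timeout: ok is false when nothing arrived before the timeout expired.
type popFunc func(timeout time.Duration) (item string, ok bool)

// channelPop returns a popFunc backed by an in-memory channel, standing in
// for the Redis list so the sketch runs without a server.
func channelPop(queue <-chan string) popFunc {
	return func(timeout time.Duration) (string, bool) {
		select {
		case item := <-queue:
			return item, true
		case <-time.After(timeout):
			return "", false
		}
	}
}

// pollQueue re-checks ctx between bounded pops, so a cancelled worker stops
// within roughly one timeout instead of blocking forever.
func pollQueue(ctx context.Context, pop popFunc, timeout time.Duration, handle func(string)) {
	for ctx.Err() == nil {
		if item, ok := pop(timeout); ok {
			handle(item)
		}
	}
}

// demo cancels the worker, waits for it to exit and then pushes a second item.
func demo() (handled []string, leftover int) {
	queue := make(chan string, 1)
	results := make(chan string, 1)
	exited := make(chan struct{})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		defer close(exited)
		pollQueue(ctx, channelPop(queue), 10*time.Millisecond, func(s string) { results <- s })
	}()

	queue <- "first"
	handled = append(handled, <-results)

	cancel()
	<-exited // returns after at most about one more timeout

	queue <- "second" // no stale worker is left to take it
	return handled, len(queue)
}

func main() {
	handled, leftover := demo()
	fmt.Println(handled, leftover) // [first] 1
}
```

The trade-off is shutdown latency: a cancelled worker can linger for up to one timeout, so a short BLPOP timeout (given in seconds) keeps tests fast at the cost of more round trips to Redis.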