
Testing golang middleware that modifies the request


I have some middleware that attaches a context carrying a request ID to the request.

func AddContextWithRequestID(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        var ctx context.Context
        ctx = NewContextWithRequestID(ctx, r)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

How do I write a test for this?


Solution

  • To test this, run the middleware with a request and a custom next handler that checks that the request was actually modified.

    You can create that handler as follows:

    (I am assuming your NewContextWithRequestID stores the value "1234" under a "reqId" key in the request's context; adjust the assertions to match your implementation.)

    // create a handler to use as "next" which will verify the request
    nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        val := r.Context().Value("reqId")
        if val == nil {
            t.Error("reqId not present")
            return // nothing more to check, avoid cascading failures
        }
        valStr, ok := val.(string)
        if !ok {
            t.Error("reqId is not a string")
            return
        }
        if valStr != "1234" {
            t.Errorf("wrong reqId: got %q, want %q", valStr, "1234")
        }
    })
    

    You can then use that handler as your next one:

    // create the handler to test, using our custom "next" handler
    handlerToTest := AddContextWithRequestID(nextHandler)
    

    And then invoke that handler:

    // create a mock request to use
    req := httptest.NewRequest("GET", "http://testing", nil)
    // call the handler with a response recorder (its output is not inspected here)
    handlerToTest.ServeHTTP(httptest.NewRecorder(), req)
    

    Putting everything together, a complete working test looks like this.

    Note: I fixed a small bug in your original AddContextWithRequestID: ctx was declared without being initialized, so it was still nil when passed to NewContextWithRequestID.

    import (
        "context"
        "net/http"
        "net/http/httptest"
        "testing"
    )
    
    func NewContextWithRequestID(ctx context.Context, r *http.Request) context.Context {
        return context.WithValue(ctx, "reqId", "1234")
    }
    
    func AddContextWithRequestID(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            var ctx = context.Background()
            ctx = NewContextWithRequestID(ctx, r)
            next.ServeHTTP(w, r.WithContext(ctx))
        })
    }
    
    func TestIt(t *testing.T) {
        // create a handler to use as "next" which will verify the request
        nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            val := r.Context().Value("reqId")
            if val == nil {
                t.Error("reqId not present")
                return // nothing more to check, avoid cascading failures
            }
            valStr, ok := val.(string)
            if !ok {
                t.Error("reqId is not a string")
                return
            }
            if valStr != "1234" {
                t.Errorf("wrong reqId: got %q, want %q", valStr, "1234")
            }
        })
    
        // create the handler to test, using our custom "next" handler
        handlerToTest := AddContextWithRequestID(nextHandler)
    
        // create a mock request to use
        req := httptest.NewRequest("GET", "http://testing", nil)
    
        // call the handler with a response recorder (its output is not inspected here)
        handlerToTest.ServeHTTP(httptest.NewRecorder(), req)
    }