diff --git a/handlers_test.go b/handlers_test.go index e3f3650bb..23c0ee523 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -40,6 +40,96 @@ import ( margo "go.mozilla.org/mar" ) +type HandlerTestCase struct { + name string + method string + url string + + // urlRouteVars are https://pkg.go.dev/github.com/gorilla/mux#Vars + // as configured with the handler at /config/{keyid:[a-zA-Z0-9-_]{1,64}} + // there should only be a keyid var and it should match the url value + urlRouteVars map[string]string + + // headers are additional http headers to set + headers *http.Header + + // user/auth ID to build an Authorization header for + authorizeID string + nilBody bool + body string + + expectedStatus int + expectedHeaders http.Header + expectedBody string +} + +func (testcase *HandlerTestCase) NewRequest(t *testing.T) *http.Request { + // test request setup + var ( + req *http.Request + err error + ) + if testcase.nilBody { + req, err = http.NewRequest(testcase.method, testcase.url, nil) + } else { + req, err = http.NewRequest(testcase.method, testcase.url, strings.NewReader(testcase.body)) + } + if err != nil { + t.Fatal(err) + } + req = mux.SetURLVars(req, testcase.urlRouteVars) + + if testcase.headers != nil { + req.Header = *testcase.headers + } + if testcase.authorizeID != "" { + auth, err := ag.getAuthByID(testcase.authorizeID) + if err != nil { + t.Fatal(err) + } + // getAuthHeader requires a content type and body + req.Header.Set("Authorization", hawk.NewRequestAuth(req, + &hawk.Credentials{ + ID: auth.ID, + Key: auth.Key, + Hash: sha256.New}, + 0).RequestHeader()) + } + + return req +} + +func (testcase *HandlerTestCase) ValidateResponse(t *testing.T, w *httptest.ResponseRecorder) { + if w.Code != testcase.expectedStatus { + t.Fatalf("test case %s: got code %d but expected %d", + testcase.name, w.Code, testcase.expectedStatus) + } + if w.Body.String() != testcase.expectedBody { + t.Fatalf("test case %s: got body %q expected %q", testcase.name, w.Body.String(), testcase.expectedBody) + } + for expectedHeader, expectedHeaderVals := range testcase.expectedHeaders { + vals, ok := w.Header()[expectedHeader] + if !ok { + t.Fatalf("test case %s: expected header %q not found", testcase.name, expectedHeader) + } + if strings.Join(vals, "") != strings.Join(expectedHeaderVals, "") { + t.Fatalf("test case %s: header vals %q did not match expected %q ", testcase.name, vals, expectedHeaderVals) + } + } +} + +func (testcase *HandlerTestCase) Run(t *testing.T, handler func(http.ResponseWriter, *http.Request)) { + // test request setup + var req = testcase.NewRequest(t) + + // run the request + w := httptest.NewRecorder() + handler(w, req) + + // validate response + testcase.ValidateResponse(t, w) +} + func TestBadRequest(t *testing.T) { t.Parallel() @@ -271,38 +361,41 @@ func TestLBHeartbeat(t *testing.T) { } } -func checkHeartbeatReturnsExpectedStatusAndBody(t *testing.T, name, method string, expectedStatusCode int, expectedBody []byte) { - req, err := http.NewRequest(method, "http://foo.bar/__heartbeat__", nil) - if err != nil { - t.Fatal(err) - } - w := httptest.NewRecorder() - ag.handleHeartbeat(w, req) - if w.Code != expectedStatusCode { - t.Fatalf("test case %s failed with code %d but %d was expected", - name, w.Code, expectedStatusCode) - } - if !bytes.Equal(w.Body.Bytes(), expectedBody) { - t.Fatalf("test case %s returned unexpected heartbeat body %q expected %q", name, w.Body.Bytes(), expectedBody) - } -} - func TestHeartbeat(t *testing.T) { t.Parallel() - var TESTCASES = []struct { - name string - method string - expectedHTTPStatus int - expectedBody string - }{ - {"returns 200 for GET", `GET`, http.StatusOK, "{}"}, - {"returns 405 for POST", `POST`, http.StatusMethodNotAllowed, "POST method not allowed; endpoint accepts GET only\r\nrequest-id: -\n"}, - {"returns 405 for PUT", `PUT`, http.StatusMethodNotAllowed, "PUT method not allowed; endpoint accepts GET only\r\nrequest-id: -\n"}, - {"returns 405 for HEAD", `HEAD`, http.StatusMethodNotAllowed, "HEAD method not allowed; endpoint accepts GET only\r\nrequest-id: -\n"}, + var TESTCASES = []HandlerTestCase{ + { + name: "returns 200 for GET", + method: "GET", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusOK, + expectedBody: "{}", + }, + { + name: "returns 405 for POST", + method: "POST", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusMethodNotAllowed, + expectedBody: "POST method not allowed; endpoint accepts GET only\r\nrequest-id: -\n", + }, + { + name: "returns 405 for PUT", + method: "PUT", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusMethodNotAllowed, + expectedBody: "PUT method not allowed; endpoint accepts GET only\r\nrequest-id: -\n", + }, + { + name: "returns 405 for HEAD", + method: "HEAD", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusMethodNotAllowed, + expectedBody: "HEAD method not allowed; endpoint accepts GET only\r\nrequest-id: -\n", + }, } for _, testcase := range TESTCASES { - checkHeartbeatReturnsExpectedStatusAndBody(t, testcase.name, testcase.method, testcase.expectedHTTPStatus, []byte((testcase.expectedBody))) + testcase.Run(t, ag.handleHeartbeat) } } @@ -313,9 +406,14 @@ func TestHeartbeatChecksHSMStatusFails(t *testing.T) { hsmSignerConf: &ag.getSigners()[0].(*contentsignature.ContentSigner).Configuration, } - expectedStatus := http.StatusInternalServerError - expectedBody := []byte("{\"hsmAccessible\":false}") - checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 500 for GET with HSM inaccessible", `GET`, expectedStatus, expectedBody) + var testcase = HandlerTestCase{ + name: "returns 500 for GET with HSM inaccessible", + method: "GET", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusInternalServerError, + expectedBody: "{\"hsmAccessible\":false}", + } + testcase.Run(t, ag.handleHeartbeat) ag.heartbeatConf = nil } @@ -324,9 +422,14 @@ func TestHeartbeatChecksHSMStatusFailsWhenNotConfigured(t *testing.T) { // NB: do not run in parallel with TestHeartbeat* ag.heartbeatConf = nil - expectedStatus := http.StatusInternalServerError - expectedBody := []byte("Missing heartbeat config\r\nrequest-id: -\n") - checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 500 for GET without heartbeat config HSM", `GET`, expectedStatus, expectedBody) + var testcase = HandlerTestCase{ + name: "returns 500 for GET without heartbeat config HSM", + method: "GET", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusInternalServerError, + expectedBody: "Missing heartbeat config\r\nrequest-id: -\n", + } + testcase.Run(t, ag.handleHeartbeat) } func TestHeartbeatChecksDBStatusOKAndTimesout(t *testing.T) { @@ -346,25 +449,39 @@ func TestHeartbeatChecksDBStatusOKAndTimesout(t *testing.T) { DBCheckTimeout: 2 * time.Second, } - // check OK run locally requires running DB container - expectedStatus := http.StatusOK - expectedBody := []byte("{\"dbAccessible\":true}") - checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 200 for GET with DB accessible", `GET`, expectedStatus, expectedBody) + var dbAccessibleTestCase = HandlerTestCase{ + name: "returns 200 for GET with DB accessible", + method: "GET", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusOK, + expectedBody: "{\"dbAccessible\":true}", + } + dbAccessibleTestCase.Run(t, ag.handleHeartbeat) // drop timeout ag.heartbeatConf.DBCheckTimeout = 1 * time.Nanosecond // check DB request times out - expectedStatus = http.StatusOK - expectedBody = []byte("{\"dbAccessible\":false}") - checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 200 for GET with DB time out", `GET`, expectedStatus, expectedBody) + var dbTimeoutTestCase = HandlerTestCase{ + name: "returns 200 for GET with DB time out", + method: "GET", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusOK, + expectedBody: "{\"dbAccessible\":false}", + } + dbTimeoutTestCase.Run(t, ag.handleHeartbeat) // restore longer timeout and close the DB connection ag.heartbeatConf.DBCheckTimeout = 1 * time.Second db.Close() // check DB request still fails - expectedStatus = http.StatusOK - expectedBody = []byte("{\"dbAccessible\":false}") - checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 200 for GET with DB inaccessible", `GET`, expectedStatus, expectedBody) + var dbOfflineTestCase = HandlerTestCase{ + name: "returns 200 for GET with DB inaccessible", + method: "GET", + url: "http://foo.bar/__heartbeat__", + expectedStatus: http.StatusOK, + expectedBody: "{\"dbAccessible\":false}", + } + dbOfflineTestCase.Run(t, ag.handleHeartbeat) ag.db = nil } @@ -572,96 +689,6 @@ func TestDebug(t *testing.T) { } } -type HandlerTestCase struct { - name string - method string - url string - - // urlRouteVars are https://pkg.go.dev/github.com/gorilla/mux#Vars - // as configured with the handler at /config/{keyid:[a-zA-Z0-9-_]{1,64}} - // there should only be a keyid var and it should match the url value - urlRouteVars map[string]string - - // headers are additional http headers to set - headers *http.Header - - // user/auth ID to build an Authorization header for - authorizeID string - nilBody bool - body string - - expectedStatus int - expectedHeaders http.Header - expectedBody string -} - -func (testcase *HandlerTestCase) NewRequest(t *testing.T) *http.Request { - // test request setup - var ( - req *http.Request - err error - ) - if testcase.nilBody { - req, err = http.NewRequest(testcase.method, testcase.url, nil) - } else { - req, err = http.NewRequest(testcase.method, testcase.url, strings.NewReader(testcase.body)) - } - if err != nil { - t.Fatal(err) - } - req = mux.SetURLVars(req, testcase.urlRouteVars) - - if testcase.headers != nil { - req.Header = *testcase.headers - } - if testcase.authorizeID != "" { - auth, err := ag.getAuthByID(testcase.authorizeID) - if err != nil { - t.Fatal(err) - } - // getAuthHeader requires a content type and body - req.Header.Set("Authorization", hawk.NewRequestAuth(req, - &hawk.Credentials{ - ID: auth.ID, - Key: auth.Key, - Hash: sha256.New}, - 0).RequestHeader()) - } - - return req -} - -func (testcase* HandlerTestCase) ValidateResponse(t *testing.T, w *httptest.ResponseRecorder) { - if w.Code != testcase.expectedStatus { - t.Fatalf("test case %s: got code %d but expected %d", - testcase.name, w.Code, testcase.expectedStatus) - } - if w.Body.String() != testcase.expectedBody { - t.Fatalf("test case %s: got body %q expected %q", testcase.name, w.Body.String(), testcase.expectedBody) - } - for expectedHeader, expectedHeaderVals := range testcase.expectedHeaders { - vals, ok := w.Header()[expectedHeader] - if !ok { - t.Fatalf("test case %s: expected header %q not found", testcase.name, expectedHeader) - } - if strings.Join(vals, "") != strings.Join(expectedHeaderVals, "") { - t.Fatalf("test case %s: header vals %q did not match expected %q ", testcase.name, vals, expectedHeaderVals) - } - } -} - -func (testcase* HandlerTestCase) Run(t *testing.T, handler func(http.ResponseWriter, *http.Request)) { - // test request setup - var req = testcase.NewRequest(t) - - // run the request - w := httptest.NewRecorder() - handler(w, req) - - // validate response - testcase.ValidateResponse(t, w) -} - func TestHandleGetAuthKeyIDs(t *testing.T) { t.Parallel()