Skip to content

Commit

Permalink
Use the common helpers for hearbeat testing too
Browse files Browse the repository at this point in the history
  • Loading branch information
oskirby committed Jun 13, 2024
1 parent dc7e3d2 commit 4e4af28
Showing 1 changed file with 160 additions and 133 deletions.
293 changes: 160 additions & 133 deletions handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,96 @@ import (
margo "go.mozilla.org/mar"
)

type HandlerTestCase struct {
name string
method string
url string

// urlRouteVars are https://pkg.go.dev/github.com/gorilla/mux#Vars
// as configured with the handler at /config/{keyid:[a-zA-Z0-9-_]{1,64}}
// there should only be a keyid var and it should match the url value
urlRouteVars map[string]string

// headers are additional http headers to set
headers *http.Header

// user/auth ID to build an Authorization header for
authorizeID string
nilBody bool
body string

expectedStatus int
expectedHeaders http.Header
expectedBody string
}

func (testcase *HandlerTestCase) NewRequest(t *testing.T) *http.Request {
// test request setup
var (
req *http.Request
err error
)
if testcase.nilBody {
req, err = http.NewRequest(testcase.method, testcase.url, nil)
} else {
req, err = http.NewRequest(testcase.method, testcase.url, strings.NewReader(testcase.body))
}
if err != nil {
t.Fatal(err)
}
req = mux.SetURLVars(req, testcase.urlRouteVars)

if testcase.headers != nil {
req.Header = *testcase.headers
}
if testcase.authorizeID != "" {
auth, err := ag.getAuthByID(testcase.authorizeID)
if err != nil {
t.Fatal(err)
}
// getAuthHeader requires a content type and body
req.Header.Set("Authorization", hawk.NewRequestAuth(req,
&hawk.Credentials{
ID: auth.ID,
Key: auth.Key,
Hash: sha256.New},
0).RequestHeader())
}

return req
}

func (testcase *HandlerTestCase) ValidateResponse(t *testing.T, w *httptest.ResponseRecorder) {
if w.Code != testcase.expectedStatus {
t.Fatalf("test case %s: got code %d but expected %d",
testcase.name, w.Code, testcase.expectedStatus)
}
if w.Body.String() != testcase.expectedBody {
t.Fatalf("test case %s: got body %q expected %q", testcase.name, w.Body.String(), testcase.expectedBody)
}
for expectedHeader, expectedHeaderVals := range testcase.expectedHeaders {
vals, ok := w.Header()[expectedHeader]
if !ok {
t.Fatalf("test case %s: expected header %q not found", testcase.name, expectedHeader)
}
if strings.Join(vals, "") != strings.Join(expectedHeaderVals, "") {
t.Fatalf("test case %s: header vals %q did not match expected %q ", testcase.name, vals, expectedHeaderVals)
}
}
}

func (testcase *HandlerTestCase) Run(t *testing.T, handler func(http.ResponseWriter, *http.Request)) {
// test request setup
var req = testcase.NewRequest(t)

// run the request
w := httptest.NewRecorder()
handler(w, req)

// validate response
testcase.ValidateResponse(t, w)
}

func TestBadRequest(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -271,38 +361,41 @@ func TestLBHeartbeat(t *testing.T) {
}
}

func checkHeartbeatReturnsExpectedStatusAndBody(t *testing.T, name, method string, expectedStatusCode int, expectedBody []byte) {
req, err := http.NewRequest(method, "http://foo.bar/__heartbeat__", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
ag.handleHeartbeat(w, req)
if w.Code != expectedStatusCode {
t.Fatalf("test case %s failed with code %d but %d was expected",
name, w.Code, expectedStatusCode)
}
if !bytes.Equal(w.Body.Bytes(), expectedBody) {
t.Fatalf("test case %s returned unexpected heartbeat body %q expected %q", name, w.Body.Bytes(), expectedBody)
}
}

func TestHeartbeat(t *testing.T) {
t.Parallel()

var TESTCASES = []struct {
name string
method string
expectedHTTPStatus int
expectedBody string
}{
{"returns 200 for GET", `GET`, http.StatusOK, "{}"},
{"returns 405 for POST", `POST`, http.StatusMethodNotAllowed, "POST method not allowed; endpoint accepts GET only\r\nrequest-id: -\n"},
{"returns 405 for PUT", `PUT`, http.StatusMethodNotAllowed, "PUT method not allowed; endpoint accepts GET only\r\nrequest-id: -\n"},
{"returns 405 for HEAD", `HEAD`, http.StatusMethodNotAllowed, "HEAD method not allowed; endpoint accepts GET only\r\nrequest-id: -\n"},
var TESTCASES = []HandlerTestCase{
{
name: "returns 200 for GET",
method: "GET",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusOK,
expectedBody: "{}",
},
{
name: "returns 405 for POST",
method: "POST",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "POST method not allowed; endpoint accepts GET only\r\nrequest-id: -\n",
},
{
name: "returns 405 for PUT",
method: "PUT",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "PUT method not allowed; endpoint accepts GET only\r\nrequest-id: -\n",
},
{
name: "returns 405 for HEAD",
method: "HEAD",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "HEAD method not allowed; endpoint accepts GET only\r\nrequest-id: -\n",
},
}
for _, testcase := range TESTCASES {
checkHeartbeatReturnsExpectedStatusAndBody(t, testcase.name, testcase.method, testcase.expectedHTTPStatus, []byte((testcase.expectedBody)))
testcase.Run(t, ag.handleHeartbeat)
}
}

Expand All @@ -313,9 +406,14 @@ func TestHeartbeatChecksHSMStatusFails(t *testing.T) {
hsmSignerConf: &ag.getSigners()[0].(*contentsignature.ContentSigner).Configuration,
}

expectedStatus := http.StatusInternalServerError
expectedBody := []byte("{\"hsmAccessible\":false}")
checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 500 for GET with HSM inaccessible", `GET`, expectedStatus, expectedBody)
var testcase = HandlerTestCase{
name: "returns 500 for GET with HSM inaccessible",
method: "GET",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusInternalServerError,
expectedBody: "{\"hsmAccessible\":false}",
}
testcase.Run(t, ag.handleHeartbeat)

ag.heartbeatConf = nil
}
Expand All @@ -324,9 +422,14 @@ func TestHeartbeatChecksHSMStatusFailsWhenNotConfigured(t *testing.T) {
// NB: do not run in parallel with TestHeartbeat*
ag.heartbeatConf = nil

expectedStatus := http.StatusInternalServerError
expectedBody := []byte("Missing heartbeat config\r\nrequest-id: -\n")
checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 500 for GET without heartbeat config HSM", `GET`, expectedStatus, expectedBody)
var testcase = HandlerTestCase{
name: "returns 500 for GET without heartbeat config HSM",
method: "GET",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusInternalServerError,
expectedBody: "Missing heartbeat config\r\nrequest-id: -\n",
}
testcase.Run(t, ag.handleHeartbeat)
}

func TestHeartbeatChecksDBStatusOKAndTimesout(t *testing.T) {
Expand All @@ -346,25 +449,39 @@ func TestHeartbeatChecksDBStatusOKAndTimesout(t *testing.T) {
DBCheckTimeout: 2 * time.Second,
}

// check OK run locally requires running DB container
expectedStatus := http.StatusOK
expectedBody := []byte("{\"dbAccessible\":true}")
checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 200 for GET with DB accessible", `GET`, expectedStatus, expectedBody)
var dbAccessibleTestCase = HandlerTestCase{
name: "returns 200 for GET with DB accessible",
method: "GET",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusOK,
expectedBody: "{\"dbAccessible\":true}",
}
dbAccessibleTestCase.Run(t, ag.handleHeartbeat)

// drop timeout
ag.heartbeatConf.DBCheckTimeout = 1 * time.Nanosecond
// check DB request times out
expectedStatus = http.StatusOK
expectedBody = []byte("{\"dbAccessible\":false}")
checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 200 for GET with DB time out", `GET`, expectedStatus, expectedBody)
var dbTimeoutTestCase = HandlerTestCase{
name: "returns 200 for GET with DB time out",
method: "GET",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusOK,
expectedBody: "{\"dbAccessible\":false}",
}
dbTimeoutTestCase.Run(t, ag.handleHeartbeat)

// restore longer timeout and close the DB connection
ag.heartbeatConf.DBCheckTimeout = 1 * time.Second
db.Close()
// check DB request still fails
expectedStatus = http.StatusOK
expectedBody = []byte("{\"dbAccessible\":false}")
checkHeartbeatReturnsExpectedStatusAndBody(t, "returns 200 for GET with DB inaccessible", `GET`, expectedStatus, expectedBody)
var dbOfflineTestCase = HandlerTestCase{
name: "returns 200 for GET with DB inaccessible",
method: "GET",
url: "http://foo.bar/__heartbeat__",
expectedStatus: http.StatusOK,
expectedBody: "{\"dbAccessible\":false}",
}
dbOfflineTestCase.Run(t, ag.handleHeartbeat)

ag.db = nil
}
Expand Down Expand Up @@ -572,96 +689,6 @@ func TestDebug(t *testing.T) {
}
}

type HandlerTestCase struct {
name string
method string
url string

// urlRouteVars are https://pkg.go.dev/github.com/gorilla/mux#Vars
// as configured with the handler at /config/{keyid:[a-zA-Z0-9-_]{1,64}}
// there should only be a keyid var and it should match the url value
urlRouteVars map[string]string

// headers are additional http headers to set
headers *http.Header

// user/auth ID to build an Authorization header for
authorizeID string
nilBody bool
body string

expectedStatus int
expectedHeaders http.Header
expectedBody string
}

func (testcase *HandlerTestCase) NewRequest(t *testing.T) *http.Request {
// test request setup
var (
req *http.Request
err error
)
if testcase.nilBody {
req, err = http.NewRequest(testcase.method, testcase.url, nil)
} else {
req, err = http.NewRequest(testcase.method, testcase.url, strings.NewReader(testcase.body))
}
if err != nil {
t.Fatal(err)
}
req = mux.SetURLVars(req, testcase.urlRouteVars)

if testcase.headers != nil {
req.Header = *testcase.headers
}
if testcase.authorizeID != "" {
auth, err := ag.getAuthByID(testcase.authorizeID)
if err != nil {
t.Fatal(err)
}
// getAuthHeader requires a content type and body
req.Header.Set("Authorization", hawk.NewRequestAuth(req,
&hawk.Credentials{
ID: auth.ID,
Key: auth.Key,
Hash: sha256.New},
0).RequestHeader())
}

return req
}

func (testcase* HandlerTestCase) ValidateResponse(t *testing.T, w *httptest.ResponseRecorder) {
if w.Code != testcase.expectedStatus {
t.Fatalf("test case %s: got code %d but expected %d",
testcase.name, w.Code, testcase.expectedStatus)
}
if w.Body.String() != testcase.expectedBody {
t.Fatalf("test case %s: got body %q expected %q", testcase.name, w.Body.String(), testcase.expectedBody)
}
for expectedHeader, expectedHeaderVals := range testcase.expectedHeaders {
vals, ok := w.Header()[expectedHeader]
if !ok {
t.Fatalf("test case %s: expected header %q not found", testcase.name, expectedHeader)
}
if strings.Join(vals, "") != strings.Join(expectedHeaderVals, "") {
t.Fatalf("test case %s: header vals %q did not match expected %q ", testcase.name, vals, expectedHeaderVals)
}
}
}

func (testcase* HandlerTestCase) Run(t *testing.T, handler func(http.ResponseWriter, *http.Request)) {
// test request setup
var req = testcase.NewRequest(t)

// run the request
w := httptest.NewRecorder()
handler(w, req)

// validate response
testcase.ValidateResponse(t, w)
}

func TestHandleGetAuthKeyIDs(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 4e4af28

Please sign in to comment.