diff --git a/cmd/handlers/handlers_test.go b/cmd/handlers/handlers_test.go index bbe413b4..8e3a5941 100644 --- a/cmd/handlers/handlers_test.go +++ b/cmd/handlers/handlers_test.go @@ -36,8 +36,13 @@ func setupTestDependencies() *TestDependencies { } } +// Writes an OK status to the response +func mockHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + w.WriteHeader(http.StatusOK) +} + func TestProcessReceipt(t *testing.T) { - deps := setupTestDependencies() + d := setupTestDependencies() tests := []struct { name string input ReceiptInput @@ -66,7 +71,7 @@ func TestProcessReceipt(t *testing.T) { resp := httptest.NewRecorder() // Test ProcessReceipt function - deps.handlers.ProcessReceipt(resp, req) + d.handlers.ProcessReceipt(resp, req) // Check the response status if status := resp.Code; status != entry.expectedStatus { @@ -106,6 +111,7 @@ func TestProcessReceipt(t *testing.T) { } func TestGetReceiptPoints(t *testing.T) { + d := setupTestDependencies() router := httprouter.New() // Register the route router.GET("/receipts/:id/points", mockHandler) @@ -123,64 +129,7 @@ func TestGetReceiptPoints(t *testing.T) { for _, entry := range tests { t.Run(entry.name, func(t *testing.T) { - receiptStore.Insert(*SimpleReceipt) - - // Create request and response - req := httptest.NewRequest(http.MethodGet, entry.url, nil) - // Create request context that will enable id extraction - params := httprouter.Params{ - httprouter.Param{Key: "id", Value: entry.ID}, - } - ctx := context.WithValue(req.Context(), httprouter.ParamsKey, params) - req = req.WithContext(ctx) - // Create response - resp := httptest.NewRecorder() - - // Serve the request - router.ServeHTTP(resp, req) - - // Test GetReceiptPoints function - handlers.GetReceiptPoints(resp, req) - - // Check response status - if resp.Code == http.StatusOK && entry.status == http.StatusOK { - var response PointsResponse - // Decode the request body - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - // Check if the points match - if response.Points != SimpleReceipt.Points { - t.Fatalf("Expected %d points, received %d", SimpleReceipt.Points, response.Points) - } - } - - t.Cleanup(func() { - receiptStore = models.NewStore() - }) - }) - } -} - -func TestGetReceiptPoints(t *testing.T) { - router := httprouter.New() - // Register the route - router.GET("/receipts/:id/points", mockHandler) - - tests := []struct { - name string - ID string - url string - status int - }{ - {"Valid receipt id", SimpleReceipt.ID, "/receipts/123-qwe-456-rty-7890/points", http.StatusOK}, - {"Invalid receipt id", "hello-world", "/receipts/hello-world/points", http.StatusNotFound}, - {"Invalid request url", SimpleReceipt.ID, "/hello/123-qwe-456-rty-7890/world", http.StatusNotFound}, - } - - for _, entry := range tests { - t.Run(entry.name, func(t *testing.T) { - receiptStore.Insert(*SimpleReceipt) + d.receiptStore.Insert(*SimpleReceipt) // Create request and response req := httptest.NewRequest(http.MethodGet, entry.url, nil) @@ -197,7 +146,7 @@ func TestGetReceiptPoints(t *testing.T) { router.ServeHTTP(resp, req) // Test GetReceiptPoints function - handlers.GetReceiptPoints(resp, req) + d.handlers.GetReceiptPoints(resp, req) // Check response status if resp.Code == http.StatusOK && entry.status == http.StatusOK { @@ -213,7 +162,7 @@ func TestGetReceiptPoints(t *testing.T) { } t.Cleanup(func() { - receiptStore = models.NewStore() + d.receiptStore = models.NewStore() }) }) }