Skip to content

Commit e4552f4

Browse files
committed
Minor refactor + code comments
1 parent 6604036 commit e4552f4

File tree

5 files changed

+55
-68
lines changed

5 files changed

+55
-68
lines changed

cache.go

+18-10
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,21 @@ import (
88
"time"
99
)
1010

11-
// Basic in-memory cache for response data (currently only a named type for map)
11+
// Simple in-memory cache middleware.
12+
//
13+
// Response are indexed by the request URL path and are stored
14+
// only if they contain the "Expires" header.
1215
type CacheMiddleware struct {
1316
cache map[string]*cachedResponse
1417
}
1518

19+
// A cached response with its given Expires
1620
type cachedResponse struct {
1721
responseData []byte
1822
expires int64
1923
}
2024

21-
// Creates an empty in-memory cache middleware
25+
// Creates an empty cache middleware
2226
func NewCacheMiddleware() *CacheMiddleware {
2327
var c CacheMiddleware
2428
c.cache = make(map[string]*cachedResponse)
@@ -31,23 +35,22 @@ func (c *CacheMiddleware) ProcessRequest(rw http.ResponseWriter, req *http.Reque
3135

3236
entry, ok := c.cache[req.URL.Path]
3337

34-
// Not on cache or expired response
38+
// Response is not on cache or already expired
3539
if !ok || entry.expires <= time.Now().Unix() {
36-
log.Printf("Cache MISS")
37-
rw.Header().Add("X-Cache-Status", "MISS")
38-
40+
setCacheStatus(rw, "MISS")
3941
delete(c.cache, req.URL.Path) // not necessary?
4042
return
4143
}
4244

43-
// Read cached (and not expired) response
45+
// Read cached response
4446
res, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(entry.responseData)), req)
4547
if err != nil {
46-
log.Fatal(err)
48+
log.Printf("Error reading response stored in cache: %v", err)
49+
setCacheStatus(rw, "MISS")
50+
return
4751
}
4852

49-
log.Printf("Cache HIT")
50-
rw.Header().Add("X-Cache-Status", "HIT")
53+
setCacheStatus(rw, "HIT")
5154
WriteResponse(rw, res)
5255
}
5356

@@ -71,3 +74,8 @@ func (c *CacheMiddleware) ProcessResponse(res *http.Response, req *http.Request)
7174
b := DumpResponse(res)
7275
c.cache[req.URL.Path] = &cachedResponse{responseData: b.Bytes(), expires: t.Unix()}
7376
}
77+
78+
func setCacheStatus(rw http.ResponseWriter, status string) {
79+
log.Printf("Cache %v", status)
80+
rw.Header().Add("X-Cache-Status", status)
81+
}

cache_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@ package main
22

33
import (
44
"bytes"
5-
"io/ioutil"
5+
"io"
66
"net/http"
77
"net/http/httptest"
88
"testing"
99
"time"
1010
)
1111

12-
// Tests processing requests and responses resulting in HIT/MISS
12+
// Tests requests resulting in HIT/MISS
1313
func TestCacheMiddleware(t *testing.T) {
1414
// Setup test
1515
c := NewCacheMiddleware()
@@ -63,6 +63,6 @@ func makeDummyResponse() *http.Response {
6363
ProtoMajor: 1,
6464
ProtoMinor: 0,
6565
Header: make(http.Header),
66-
Body: ioutil.NopCloser(bytes.NewBufferString("Hello World")),
66+
Body: io.NopCloser(bytes.NewBufferString("Hello World")),
6767
}
6868
}

reverso.go

+18-34
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,52 @@
11
package main
22

33
import (
4-
"fmt"
54
"log"
65
"net/http"
76
"net/url"
87
)
98

109
// Reverso is an HTTP handler behaving as a reverse proxy.
1110
//
12-
// Reverso forwards incoming requests to a target server and
13-
// sends the response back to the client.
11+
// Incoming requests are forwarded to the host specified in originURL.
12+
//
13+
// Responses containing the "Expires" header are stored in an in-memory cache
14+
// and served from there on further requests, as long as they do not expire
1415
type Reverso struct {
1516
// Origin server URL to forward requests.
1617
originURL url.URL
1718

18-
// In-memory cache middleware
19+
// In-memory cache middleware to store response data.
1920
cache CacheMiddleware
2021
}
2122

2223
// Handler function to responds to an HTTP request.
2324
func (r *Reverso) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
2425
log.Println(req.Method, req.URL.Path)
2526

26-
r.cache.ProcessRequest(rw, req)
27+
// Fetch from cache if available
28+
r.cache.ProcessRequest(rw, req) // writes X-Cache-Status header
2729

2830
if rw.Header().Get("X-Cache-Status") == "MISS" {
29-
30-
// Fetch request from origin server
31-
res, err := r.fetchFromOrigin(req)
31+
// Modify request host and scheme to point to origin server
32+
req.URL.Host = r.originURL.Host
33+
req.URL.Scheme = r.originURL.Scheme
34+
req.RequestURI = "" // Must be empty for client requests (see field docs in https://pkg.go.dev/net/http#Request)
35+
36+
// Fetch from origin server
37+
log.Printf("Forwarding request to: '%v'", req.URL.String())
38+
res, err := (&http.Client{}).Do(req)
3239
if err != nil {
33-
log.Println(err)
40+
log.Printf("Error: %v", err)
3441
rw.WriteHeader(http.StatusInternalServerError)
3542
return
3643
}
3744
b := DumpResponse(res)
3845

46+
// Process response, stores in cache if contains Expires header
3947
r.cache.ProcessResponse(ReadResponse(b.Bytes(), req), req)
4048

49+
// Write response back
4150
WriteResponse(rw, ReadResponse(b.Bytes(), req))
4251
}
4352
}
44-
45-
func (r *Reverso) fetchFromOrigin(req *http.Request) (*http.Response, error) {
46-
// Modify request to forward to origin server
47-
req.URL.Scheme = r.originURL.Scheme
48-
req.URL.Host = r.originURL.Host
49-
req.RequestURI = "" // Should be empty for client requests (see src/net/http/client.go:217)
50-
51-
log.Printf("Forwarding request to: '%v'", req.URL.String())
52-
53-
// Send request to the origin server
54-
res, err := (&http.Client{}).Do(req)
55-
if err != nil {
56-
return nil, &internalError{err.Error()}
57-
}
58-
59-
return res, nil
60-
}
61-
62-
type internalError struct {
63-
msg string
64-
}
65-
66-
func (e *internalError) Error() string {
67-
return fmt.Sprintf("Error: %v", e.msg)
68-
}

reverso_test.go

+16-20
Original file line numberDiff line numberDiff line change
@@ -10,49 +10,52 @@ import (
1010
"testing"
1111
)
1212

13-
// Checks if client receives body and headers sent from origin
13+
// Checks if client receives response sent by origin
1414
func TestHandleSimpleRequest(t *testing.T) {
1515
const expectedBodyStr string = "Hello from the other side"
1616
const customHeaderKey string = "X-Test-Header"
1717
const customHeaderVal string = "Custom header from origin"
1818

19-
// Mock objects
19+
// Setup
2020
req := httptest.NewRequest(http.MethodGet, "/", nil)
2121
rec := httptest.NewRecorder()
2222
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2323
w.Header().Set(customHeaderKey, customHeaderVal)
2424
fmt.Fprint(w, expectedBodyStr)
2525
}))
2626
defer svr.Close()
27-
28-
// Test reverse proxy
2927
r := &Reverso{originURL: parseServerURL(svr.URL), cache: *NewCacheMiddleware()}
28+
29+
// Test reverse proxy handler
3030
r.ServeHTTP(rec, req)
3131

3232
res := rec.Result()
3333
defer res.Body.Close()
34+
body, err := io.ReadAll(res.Body)
35+
if err != nil {
36+
log.Fatal(err)
37+
}
3438

35-
bodyStr := readAll(res.Body)
36-
if bodyStr != expectedBodyStr {
37-
t.Errorf("Expected '%v', got '%v'", expectedBodyStr, bodyStr)
39+
if string(body) != expectedBodyStr {
40+
t.Errorf("Expected '%v', got '%v'", expectedBodyStr, string(body))
3841
}
3942

4043
if v := res.Header.Get(customHeaderKey); v != customHeaderVal {
41-
t.Errorf("Expected header value '%v', got '%v'", customHeaderVal, v)
44+
t.Errorf("Expected header '%v:%v', got '%v'", customHeaderKey, customHeaderVal, v)
4245
}
4346
}
4447

4548
// Checks if client receives HTTP status code sent by origin
4649
func TestHandleTeapotRequest(t *testing.T) {
47-
// Mock objects
50+
// Setup
4851
req := httptest.NewRequest(http.MethodGet, "/teapot", nil)
4952
rec := httptest.NewRecorder()
5053
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5154
w.WriteHeader(http.StatusTeapot)
5255
}))
5356
defer svr.Close()
5457

55-
// Test reverse proxy
58+
// Test reverse proxy handler
5659
r := &Reverso{originURL: parseServerURL(svr.URL), cache: *NewCacheMiddleware()}
5760
r.ServeHTTP(rec, req)
5861

@@ -63,11 +66,13 @@ func TestHandleTeapotRequest(t *testing.T) {
6366
}
6467
}
6568

66-
// Checks for internal server errors when using an invalid server URL
6769
func TestInvalidOriginURL(t *testing.T) {
70+
// Setup
6871
req := httptest.NewRequest(http.MethodGet, "/", nil)
6972
rec := httptest.NewRecorder()
7073
r := &Reverso{} // empty origin URL
74+
75+
// Test reverse proxy handler
7176
r.ServeHTTP(rec, req)
7277
res := rec.Result()
7378

@@ -84,12 +89,3 @@ func parseServerURL(rawURL string) url.URL {
8489
}
8590
return *serverURL
8691
}
87-
88-
// Read all from r and return content as string
89-
func readAll(r io.Reader) string {
90-
body, err := io.ReadAll(r)
91-
if err != nil {
92-
log.Fatal(err)
93-
}
94-
return string(body)
95-
}

util.go

-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ func WriteResponse(rw http.ResponseWriter, res *http.Response) {
1313
// Write header
1414
for key, values := range res.Header {
1515
for _, value := range values {
16-
log.Println(key, value)
1716
rw.Header().Add(key, value)
1817
}
1918
}

0 commit comments

Comments
 (0)