Skip to content

Commit d19908d

Browse files
Remove direct dependency on gorilla mux (#384)
Signed-off-by: Daniel Weiße <[email protected]>
1 parent 83da12e commit d19908d

File tree

5 files changed

+47
-33
lines changed

5 files changed

+47
-33
lines changed

Diff for: coordinator/server/client_api.go

+14
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,20 @@ func (s *clientAPIServer) secretsPost(w http.ResponseWriter, r *http.Request) {
409409
writeJSON(w, nil)
410410
}
411411

412+
func (s *clientAPIServer) handleGetPost(getHandler, postHandler func(http.ResponseWriter, *http.Request),
413+
) func(http.ResponseWriter, *http.Request) {
414+
return func(w http.ResponseWriter, r *http.Request) {
415+
switch r.Method {
416+
case http.MethodGet:
417+
getHandler(w, r)
418+
case http.MethodPost:
419+
postHandler(w, r)
420+
default:
421+
s.methodNotAllowedHandler(w, r)
422+
}
423+
}
424+
}
425+
412426
func (s *clientAPIServer) methodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
413427
writeJSONError(w, "", http.StatusMethodNotAllowed)
414428
}

Diff for: coordinator/server/metrics.go

+15-14
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,21 @@ package server
99
import (
1010
"net/http"
1111

12-
"github.com/gorilla/mux"
1312
"github.com/prometheus/client_golang/prometheus"
1413
"github.com/prometheus/client_golang/prometheus/promauto"
1514
"github.com/prometheus/client_golang/prometheus/promhttp"
1615
)
1716

1817
// serveMux is an interface of an HTTP request multiplexer.
1918
type serveMux interface {
20-
Handle(pattern string, handler http.Handler) *mux.Route
21-
HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) *mux.Route
19+
Handle(pattern string, handler http.Handler)
20+
HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request))
2221
ServeHTTP(w http.ResponseWriter, r *http.Request)
2322
}
2423

2524
// httpMetrics is a struct of metrics for Prometheus to collect for each endpoint.
2625
type httpMetrics struct {
27-
reqest *prometheus.CounterVec
26+
request *prometheus.CounterVec
2827
duration *prometheus.HistogramVec
2928
requestSize *prometheus.HistogramVec
3029
responseSize *prometheus.HistogramVec
@@ -35,7 +34,7 @@ type httpMetrics struct {
3534
// and registres them using the given factory.
3635
func newHTTPMetrics(factory *promauto.Factory, namespace string, subsystem string, constLabels map[string]string) *httpMetrics {
3736
return &httpMetrics{
38-
reqest: factory.NewCounterVec(
37+
request: factory.NewCounterVec(
3938
prometheus.CounterOpts{
4039
Namespace: namespace,
4140
Subsystem: subsystem,
@@ -93,7 +92,7 @@ func newHTTPMetrics(factory *promauto.Factory, namespace string, subsystem strin
9392
// promServeMux is a wrapper around mux.Router with additional instrumentation to
9493
// gather Prometheus metrics.
9594
type promServeMux struct {
96-
router *mux.Router
95+
router *http.ServeMux
9796
promFactory *promauto.Factory
9897
metrics map[string]*httpMetrics
9998
namespace string
@@ -104,7 +103,7 @@ type promServeMux struct {
104103
// namespace and subsystem are used to name the exposed metrics.
105104
func newPromServeMux(factory *promauto.Factory, namespace string, subsystem string) *promServeMux {
106105
return &promServeMux{
107-
router: mux.NewRouter(),
106+
router: http.NewServeMux(),
108107
promFactory: factory,
109108
metrics: make(map[string]*httpMetrics),
110109
namespace: namespace,
@@ -114,22 +113,22 @@ func newPromServeMux(factory *promauto.Factory, namespace string, subsystem stri
114113

115114
// Handle is a wrapper around (*mux.Router) Handle form the http package
116115
// A chain of prometheus instrumentation collects metrics for the given handler.
117-
func (p *promServeMux) Handle(pattern string, handler http.Handler) *mux.Route {
116+
func (p *promServeMux) Handle(pattern string, handler http.Handler) {
118117
if p.metrics[pattern] == nil {
119118
constLabels := map[string]string{
120119
"path": pattern,
121120
}
122121
p.metrics[pattern] = newHTTPMetrics(p.promFactory, p.namespace, p.subsystem, constLabels)
123122
}
124-
return p.router.Handle(pattern, p.metricsMiddleware(pattern, handler))
123+
p.router.Handle(pattern, p.metricsMiddleware(pattern, handler))
125124
}
126125

127126
// HandleFunc registers the handler function for the given pattern.
128-
func (p *promServeMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) *mux.Route {
127+
func (p *promServeMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
129128
if handler == nil {
130129
panic("promServerMux: http: nil handler")
131130
}
132-
return p.Handle(pattern, http.HandlerFunc(handler))
131+
p.Handle(pattern, http.HandlerFunc(handler))
133132
}
134133

135134
// ServeHTTP is a wrapper around (*mux.Router) ServeHttp form the http package.
@@ -140,7 +139,7 @@ func (p *promServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
140139
// metricsMiddelware returns the handed next handler wrapped in a bunch of prometheus metric handlers.
141140
func (p *promServeMux) metricsMiddleware(pattern string, next http.Handler) http.Handler {
142141
return promhttp.InstrumentHandlerDuration(p.metrics[pattern].duration,
143-
promhttp.InstrumentHandlerCounter(p.metrics[pattern].reqest,
142+
promhttp.InstrumentHandlerCounter(p.metrics[pattern].request,
144143
promhttp.InstrumentHandlerRequestSize(p.metrics[pattern].requestSize,
145144
promhttp.InstrumentHandlerResponseSize(p.metrics[pattern].responseSize,
146145
promhttp.InstrumentHandlerInFlight(p.metrics[pattern].inflight, next),
@@ -152,9 +151,11 @@ func (p *promServeMux) metricsMiddleware(pattern string, next http.Handler) http
152151

153152
// setMethodNOtAllowedHandler sets f as instrumented handler for the mux.Router.
154153
func (p *promServeMux) setMethodNotAllowedHandler(f func(http.ResponseWriter, *http.Request)) {
155-
p.router.MethodNotAllowedHandler = http.HandlerFunc(
154+
p.router.HandleFunc(
155+
"/",
156156
func(w http.ResponseWriter, r *http.Request) {
157157
handler := p.metricsMiddleware(r.URL.Path, http.HandlerFunc(f))
158158
handler.ServeHTTP(w, r)
159-
})
159+
},
160+
)
160161
}

Diff for: coordinator/server/metrics_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,15 @@ func TestClientApiRequestMetrics(t *testing.T) {
6969
mux := CreateServeMux(api, &fac)
7070

7171
metrics := mux.(*promServeMux).metrics[tc.target]
72-
assert.Equal(0, promtest.CollectAndCount(metrics.reqest))
73-
assert.Equal(float64(0), promtest.ToFloat64(metrics.reqest.WithLabelValues(tc.expectedStatusCode, strings.ToLower(tc.method))))
72+
assert.Equal(0, promtest.CollectAndCount(metrics.request))
73+
assert.Equal(float64(0), promtest.ToFloat64(metrics.request.WithLabelValues(tc.expectedStatusCode, strings.ToLower(tc.method))))
7474

7575
for i := 1; i < 6; i++ {
7676
req := httptest.NewRequest(tc.method, tc.target, nil)
7777
resp := httptest.NewRecorder()
7878
mux.ServeHTTP(resp, req)
79-
assert.Equal(1, promtest.CollectAndCount(metrics.reqest))
80-
assert.Equal(float64(i), promtest.ToFloat64(metrics.reqest.WithLabelValues(tc.expectedStatusCode, strings.ToLower(tc.method))))
79+
assert.Equal(1, promtest.CollectAndCount(metrics.request))
80+
assert.Equal(float64(i), promtest.ToFloat64(metrics.request.WithLabelValues(tc.expectedStatusCode, strings.ToLower(tc.method))))
8181
}
8282
})
8383
}

Diff for: coordinator/server/server.go

+13-14
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"github.com/edgelesssys/marblerun/coordinator/rpc"
2020
"github.com/edgelesssys/marblerun/coordinator/state"
2121
"github.com/edgelesssys/marblerun/coordinator/user"
22-
"github.com/gorilla/mux"
2322
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
2423
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
2524
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
@@ -96,21 +95,21 @@ func CreateServeMux(api clientAPI, promFactory *promauto.Factory) serveMux {
9695
server := clientAPIServer{api}
9796
var router serveMux
9897
if promFactory != nil {
99-
router = newPromServeMux(promFactory, "server", "client_api")
100-
router.(*promServeMux).setMethodNotAllowedHandler(server.methodNotAllowedHandler)
98+
muxRouter := newPromServeMux(promFactory, "server", "client_api")
99+
muxRouter.setMethodNotAllowedHandler(server.methodNotAllowedHandler)
100+
router = muxRouter
101101
} else {
102-
router = mux.NewRouter()
103-
router.(*mux.Router).MethodNotAllowedHandler = http.HandlerFunc(server.methodNotAllowedHandler)
102+
muxRouter := http.NewServeMux()
103+
muxRouter.HandleFunc("/", server.methodNotAllowedHandler)
104+
router = muxRouter
104105
}
105-
router.HandleFunc("/status", server.statusGet).Methods("GET")
106-
router.HandleFunc("/manifest", server.manifestGet).Methods("GET")
107-
router.HandleFunc("/manifest", server.manifestPost).Methods("POST")
108-
router.HandleFunc("/quote", server.quoteGet).Methods("GET")
109-
router.HandleFunc("/recover", server.recoverPost).Methods("POST")
110-
router.HandleFunc("/update", server.updateGet).Methods("GET")
111-
router.HandleFunc("/update", server.updatePost).Methods("POST")
112-
router.HandleFunc("/secrets", server.secretsPost).Methods("POST")
113-
router.HandleFunc("/secrets", server.secretsGet).Methods("GET")
106+
107+
router.HandleFunc("/manifest", server.handleGetPost(server.manifestGet, server.manifestPost))
108+
router.HandleFunc("/update", server.handleGetPost(server.updateGet, server.updatePost))
109+
router.HandleFunc("/secrets", server.handleGetPost(server.secretsGet, server.secretsPost))
110+
router.HandleFunc("/status", server.handleGetPost(server.statusGet, server.methodNotAllowedHandler))
111+
router.HandleFunc("/quote", server.handleGetPost(server.quoteGet, server.methodNotAllowedHandler))
112+
router.HandleFunc("/recover", server.handleGetPost(server.methodNotAllowedHandler, server.recoverPost))
114113
return router
115114
}
116115

Diff for: go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ require (
99
github.com/gofrs/flock v0.8.1
1010
github.com/google/go-cmp v0.5.9
1111
github.com/google/uuid v1.3.0
12-
github.com/gorilla/mux v1.8.0
1312
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
1413
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
1514
github.com/prometheus/client_golang v1.14.0
@@ -76,6 +75,7 @@ require (
7675
github.com/google/gnostic v0.6.9 // indirect
7776
github.com/google/gofuzz v1.2.0 // indirect
7877
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
78+
github.com/gorilla/mux v1.8.0 // indirect
7979
github.com/gosuri/uitable v0.0.4 // indirect
8080
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect
8181
github.com/huandu/xstrings v1.3.3 // indirect

0 commit comments

Comments
 (0)