diff --git a/api/handler/handler.go b/api/handler/handler.go index 2b1de88..6704a77 100644 --- a/api/handler/handler.go +++ b/api/handler/handler.go @@ -161,6 +161,9 @@ type Host interface { // FuncRemoveHeader with HeaderKindResponseTrailers. This panics if // FeatureTrailers is not supported. RemoveResponseTrailer(ctx context.Context, name string) + + // GetSourceAddr supports the WebAssembly function export FuncGetSourceAddr. + GetSourceAddr(ctx context.Context) string } // eofReader is safer than reading from os.DevNull as it can never overrun @@ -206,3 +209,4 @@ func (UnimplementedHost) GetResponseTrailerValues(context.Context, string) (valu func (UnimplementedHost) SetResponseTrailerValue(context.Context, string, string) {} func (UnimplementedHost) AddResponseTrailerValue(context.Context, string, string) {} func (UnimplementedHost) RemoveResponseTrailer(context.Context, string) {} +func (UnimplementedHost) GetSourceAddr(context.Context) string { return "1.1.1.1:12345" } diff --git a/api/handler/wasm.go b/api/handler/wasm.go index 48c2e48..3608464 100644 --- a/api/handler/wasm.go +++ b/api/handler/wasm.go @@ -316,4 +316,10 @@ const ( // // TODO: document on http-wasm-abi FuncSetStatusCode = "set_status_code" + + // FuncGetSourceAddr writes the SourceAddr to memory if it isn't larger than BufLimit. + // The result is its length in bytes. Ex. "1.1.1.1:12345" or "[fe80::101e:2bdf:8bfb:b97e]:12345" + // + // TODO: document on http-wasm-abi + FuncGetSourceAddr = "get_source_addr" ) diff --git a/handler/middleware.go b/handler/middleware.go index cfdcd37..6cf49d9 100644 --- a/handler/middleware.go +++ b/handler/middleware.go @@ -565,6 +565,17 @@ func (m *middleware) writeBody(ctx context.Context, mod wazeroapi.Module, params writeBody(mod, buf, bufLen, w) } +// getSourceAddr implements the WebAssembly host function handler.FuncGetSourceAddr. +func (m *middleware) getSourceAddr(ctx context.Context, mod wazeroapi.Module, stack []uint64) { + buf := uint32(stack[0]) + bufLimit := handler.BufLimit(stack[1]) + + method := m.host.GetSourceAddr(ctx) + methodLen := writeStringIfUnderLimit(mod.Memory(), buf, bufLimit, method) + + stack[0] = uint64(methodLen) +} + func writeBody(mod wazeroapi.Module, buf, bufLen uint32, w io.Writer) { // buf_len 0 means to overwrite with nothing var b []byte @@ -694,6 +705,9 @@ func (m *middleware) instantiateHost(ctx context.Context) (wazeroapi.Module, err WithGoModuleFunction(wazeroapi.GoModuleFunc(m.writeBody), []wazeroapi.ValueType{i32, i32, i32}, []wazeroapi.ValueType{}). WithParameterNames("kind", "body", "body_len").Export(handler.FuncWriteBody). NewFunctionBuilder(). + WithGoModuleFunction(wazeroapi.GoModuleFunc(m.getSourceAddr), []wazeroapi.ValueType{i32, i32}, []wazeroapi.ValueType{i32}). + WithParameterNames("buf", "buf_limit").Export(handler.FuncGetSourceAddr). + NewFunctionBuilder(). WithGoFunction(wazeroapi.GoFunc(m.getStatusCode), []wazeroapi.ValueType{}, []wazeroapi.ValueType{i32}). WithParameterNames().Export(handler.FuncGetStatusCode). NewFunctionBuilder(). diff --git a/handler/nethttp/host.go b/handler/nethttp/host.go index a73e3f5..8ea2538 100644 --- a/handler/nethttp/host.go +++ b/handler/nethttp/host.go @@ -346,3 +346,9 @@ func addTrailer(header http.Header, name string, value string) { func removeTrailer(header http.Header, name string) { header.Del(http.TrailerPrefix + name) } + +// GetSourceAddr implements the same method as documented on handler.Host. +func (host) GetSourceAddr(ctx context.Context) string { + r := requestStateFromContext(ctx).r + return r.RemoteAddr +} diff --git a/handler/nethttp/host_test.go b/handler/nethttp/host_test.go index 1660fee..ff22a80 100644 --- a/handler/nethttp/host_test.go +++ b/handler/nethttp/host_test.go @@ -17,6 +17,7 @@ func Test_host(t *testing.T) { newCtx := func(features handler.Features) (context.Context, handler.Features) { // The below configuration supports all features. r, _ := http.NewRequest("GET", "", bytes.NewReader(nil)) + r.RemoteAddr = "1.2.3.4:12345" w := &bufferingResponseWriter{delegate: &httptest.ResponseRecorder{HeaderMap: map[string][]string{}}} return context.WithValue(testCtx, requestStateKey{}, &requestState{r: r, w: w}), features } diff --git a/tck/guest/go.mod b/tck/guest/go.mod index 1240234..874707f 100644 --- a/tck/guest/go.mod +++ b/tck/guest/go.mod @@ -2,4 +2,4 @@ module github.com/anuraaga/http-wasm-tck/guest go 1.19 -require github.com/http-wasm/http-wasm-guest-tinygo v0.1.1 +require github.com/http-wasm/http-wasm-guest-tinygo v0.3.1-0.20231031134125-487a6e2eec5e diff --git a/tck/guest/go.sum b/tck/guest/go.sum index 5e83e86..e1769d0 100644 --- a/tck/guest/go.sum +++ b/tck/guest/go.sum @@ -1,2 +1,2 @@ -github.com/http-wasm/http-wasm-guest-tinygo v0.1.1 h1:7L+MhMNDVsUAqmElG64JEZ0n1yGCKTx0w01N1b08Xhc= -github.com/http-wasm/http-wasm-guest-tinygo v0.1.1/go.mod h1:roTs1mkyGDe1CUzrL8JUXSbPNYUHnnWKMCG6epmRAZY= +github.com/http-wasm/http-wasm-guest-tinygo v0.3.1-0.20231031134125-487a6e2eec5e h1:mUEuIuH+XAtp5tkJPpwN5ZAtgvt64tYaYsk/KoK42lQ= +github.com/http-wasm/http-wasm-guest-tinygo v0.3.1-0.20231031134125-487a6e2eec5e/go.mod h1:zcKr7h/t5ha2ZWIMwV4iOqhfC/qno/tNPYgybVkn/MQ= diff --git a/tck/guest/main.go b/tck/guest/main.go index af9e147..fecbeac 100644 --- a/tck/guest/main.go +++ b/tck/guest/main.go @@ -104,6 +104,8 @@ func (h *handler) handleRequest(req api.Request, resp api.Response) (next bool, next, reqCtx = h.testReadBody(req, resp, strings.Repeat("a", 4096)) case "read_body/request/xlarge": next, reqCtx = h.testReadBody(req, resp, strings.Repeat("a", 5000)) + case "get_source_addr": + next, reqCtx = h.testGetSourceAddr(req, resp, "127.0.0.1") default: fail(resp, "unknown x-httpwasm-test-id") } @@ -214,6 +216,24 @@ func (h *handler) testReadBody(req api.Request, resp api.Response, expectedBody return true, 0 } +func (h *handler) testGetSourceAddr(req api.Request, resp api.Response, expectedAddr string) (next bool, reqCtx uint32) { + addr := req.GetSourceAddr() + raw := strings.Split(addr, ":") + if len(raw) != 2 { + fail(resp, fmt.Sprintf("get_source_addr: unknown colon count %s", req.GetSourceAddr())) + return + } + if raw[0] != expectedAddr { + fail(resp, fmt.Sprintf("get_source_addr: want %s, have %s", expectedAddr, req.GetSourceAddr())) + return + } + if len(raw[1]) <= 0 || len(raw[1]) > 5 { + fail(resp, fmt.Sprintf("get_source_addr: could not find port number '%s' from %s", raw[1], req.GetSourceAddr())) + return + } + return true, 0 +} + func fail(resp api.Response, msg string) { resp.SetStatusCode(500) resp.Headers().Set("x-httpwasm-tck-failed", msg) diff --git a/tck/run.go b/tck/run.go index 0044a05..568c60a 100644 --- a/tck/run.go +++ b/tck/run.go @@ -53,6 +53,7 @@ func Run(t *testing.T, client *http.Client, url string) { r.testAddHeaderValueRequest() r.testRemoveHeaderRequest() r.testReadBodyRequest() + r.testGetSourceAddr() } type testRunner struct { @@ -493,6 +494,24 @@ func (r *testRunner) testRemoveHeaderRequest() { } } +func (r *testRunner) testGetSourceAddr() { + hostFn := handler.FuncGetSourceAddr + + r.t.Run(hostFn, func(t *testing.T) { + req, err := http.NewRequest("GET", r.url, nil) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("x-httpwasm-tck-testid", hostFn) + resp, err := r.client.Do(req) + if err != nil { + t.Error(err) + } + checkResponse(t, resp) + }) +} + func checkResponse(t *testing.T, resp *http.Response) string { t.Helper() diff --git a/tck/tck.wasm b/tck/tck.wasm index ac0fb21..fc70d1b 100755 Binary files a/tck/tck.wasm and b/tck/tck.wasm differ diff --git a/testing/handlertest/testhandler.go b/testing/handlertest/testhandler.go index e382c46..2392a0d 100644 --- a/testing/handlertest/testhandler.go +++ b/testing/handlertest/testhandler.go @@ -49,6 +49,7 @@ func HostTest(t *testing.T, h handler.Host, newCtx func(handler.Features) (conte ht.testResponseHeaders() ht.testResponseBody() ht.testResponseTrailers() + ht.testSourceAddr() if len(ht.errText) == 0 { return nil @@ -64,6 +65,18 @@ type hostTester struct { errText []byte } +func (h *hostTester) testSourceAddr() { + ctx, _ := h.newCtx(0) // no features required + + h.t.Run("GetSourceAddr", func(t *testing.T) { + addr := h.h.GetSourceAddr(ctx) + want := "1.2.3.4:12345" + if addr != want { + t.Errorf("unexpected default source addr, want: %v, have: %v", want, addr) + } + }) +} + func (h *hostTester) testMethod() { ctx, _ := h.newCtx(0) // no features required