Skip to content

Commit

Permalink
Ensures response is handled with the same guest as the request (#69)
Browse files Browse the repository at this point in the history
Before, we had a bug where an arbitrary module was used for response
processing. This defeated the goal of request context correlation. This
fixes the problem by holding the guest module open until the response
is complete.

Fixes #68

Signed-off-by: Adrian Cole <[email protected]>
  • Loading branch information
codefromthecrypt committed May 23, 2023
1 parent 377347e commit 88a400f
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 57 deletions.
44 changes: 44 additions & 0 deletions api/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,47 @@ type Host interface {
// FeatureTrailers is not supported.
RemoveResponseTrailer(ctx context.Context, name string)
}

// eofReader is safer than reading from os.DevNull as it can never overrun
// operating system file descriptors.
type eofReader struct{}

func (eofReader) Close() (err error) { return }
func (eofReader) Read([]byte) (int, error) { return 0, io.EOF }

type UnimplementedHost struct{}

var _ Host = UnimplementedHost{}

func (UnimplementedHost) EnableFeatures(context.Context, Features) Features { return 0 }
func (UnimplementedHost) GetMethod(context.Context) string { return "GET" }
func (UnimplementedHost) SetMethod(context.Context, string) {}
func (UnimplementedHost) GetURI(context.Context) string { return "" }
func (UnimplementedHost) SetURI(context.Context, string) {}
func (UnimplementedHost) GetProtocolVersion(context.Context) string { return "HTTP/1.1" }
func (UnimplementedHost) GetRequestHeaderNames(context.Context) (names []string) { return }
func (UnimplementedHost) GetRequestHeaderValues(context.Context, string) (values []string) { return }
func (UnimplementedHost) SetRequestHeaderValue(context.Context, string, string) {}
func (UnimplementedHost) AddRequestHeaderValue(context.Context, string, string) {}
func (UnimplementedHost) RemoveRequestHeader(context.Context, string) {}
func (UnimplementedHost) RequestBodyReader(context.Context) io.ReadCloser { return eofReader{} }
func (UnimplementedHost) RequestBodyWriter(context.Context) io.Writer { return io.Discard }
func (UnimplementedHost) GetRequestTrailerNames(context.Context) (names []string) { return }
func (UnimplementedHost) GetRequestTrailerValues(context.Context, string) (values []string) { return }
func (UnimplementedHost) SetRequestTrailerValue(context.Context, string, string) {}
func (UnimplementedHost) AddRequestTrailerValue(context.Context, string, string) {}
func (UnimplementedHost) RemoveRequestTrailer(context.Context, string) {}
func (UnimplementedHost) GetStatusCode(context.Context) uint32 { return 200 }
func (UnimplementedHost) SetStatusCode(context.Context, uint32) {}
func (UnimplementedHost) GetResponseHeaderNames(context.Context) (names []string) { return }
func (UnimplementedHost) GetResponseHeaderValues(context.Context, string) (values []string) { return }
func (UnimplementedHost) SetResponseHeaderValue(context.Context, string, string) {}
func (UnimplementedHost) AddResponseHeaderValue(context.Context, string, string) {}
func (UnimplementedHost) RemoveResponseHeader(context.Context, string) {}
func (UnimplementedHost) ResponseBodyReader(context.Context) io.ReadCloser { return eofReader{} }
func (UnimplementedHost) ResponseBodyWriter(context.Context) io.Writer { return io.Discard }
func (UnimplementedHost) GetResponseTrailerNames(context.Context) (names []string) { return }
func (UnimplementedHost) GetResponseTrailerValues(context.Context, string) (values []string) { return }
func (UnimplementedHost) SetResponseTrailerValue(context.Context, string, string) {}
func (UnimplementedHost) AddResponseTrailerValue(context.Context, string, string) {}
func (UnimplementedHost) RemoveResponseTrailer(context.Context, string) {}
105 changes: 58 additions & 47 deletions handler/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ type Middleware interface {
var _ Middleware = (*middleware)(nil)

type middleware struct {
host handler.Host
runtime wazero.Runtime
hostModule, guestModule wazero.CompiledModule
moduleConfig wazero.ModuleConfig
guestConfig []byte
logger api.Logger
pool sync.Pool
features handler.Features
instanceCounter uint64
host handler.Host
runtime wazero.Runtime
guestModule wazero.CompiledModule
moduleConfig wazero.ModuleConfig
guestConfig []byte
logger api.Logger
pool sync.Pool
features handler.Features
instanceCounter uint64
}

func (m *middleware) Features() handler.Features {
Expand Down Expand Up @@ -83,29 +83,30 @@ func NewMiddleware(ctx context.Context, guest []byte, host handler.Host, opts ..
logger: o.logger,
}

if m.hostModule, err = m.compileHost(ctx); err != nil {
_ = m.Close(ctx)
if m.guestModule, err = m.compileGuest(ctx, guest); err != nil {
_ = wr.Close(ctx)
return nil, err
}

if _, err = wasi_snapshot_preview1.Instantiate(ctx, m.runtime); err != nil {
return nil, fmt.Errorf("wasm: error instantiating wasi: %w", err)
}

// Note: host modules don't use configuration
_, err = m.runtime.InstantiateModule(ctx, m.hostModule, wazero.NewModuleConfig())
if err != nil {
_ = m.runtime.Close(ctx)
return nil, fmt.Errorf("wasm: error instantiating host: %w", err)
}

if m.guestModule, err = m.compileGuest(ctx, guest); err != nil {
_ = m.Close(ctx)
return nil, err
// Detect and handle any host imports or lack thereof.
imports := detectImports(m.guestModule.ImportedFunctions())
switch {
case imports&importWasiP1 != 0:
if _, err = wasi_snapshot_preview1.Instantiate(ctx, m.runtime); err != nil {
_ = wr.Close(ctx)
return nil, fmt.Errorf("wasm: error instantiating wasi: %w", err)
}
fallthrough // proceed to configure any http_handler imports
case imports&importHttpHandler != 0:
if _, err = m.instantiateHost(ctx); err != nil {
_ = wr.Close(ctx)
return nil, fmt.Errorf("wasm: error instantiating host: %w", err)
}
}

// Eagerly add one instance to the pool. Doing so helps to fail fast.
if g, err := m.newGuest(ctx); err != nil {
_ = m.Close(ctx)
_ = wr.Close(ctx)
return nil, err
} else {
m.pool.Put(g)
Expand Down Expand Up @@ -139,11 +140,11 @@ func (m *middleware) HandleRequest(ctx context.Context) (outCtx context.Context,
err = guestErr
return
}
defer m.pool.Put(g)

s := &requestState{features: m.features}
s := &requestState{features: m.features, putPool: m.pool.Put, g: g}
defer func() {
if ctxNext != 0 { // will call the next handler
callNext := ctxNext != 0
if callNext { // will call the next handler
if closeErr := s.closeRequest(); err == nil {
err = closeErr
}
Expand Down Expand Up @@ -173,16 +174,10 @@ func (m *middleware) getOrCreateGuest(ctx context.Context) (*guest, error) {

// HandleResponse implements Middleware.HandleResponse
func (m *middleware) HandleResponse(ctx context.Context, reqCtx uint32, hostErr error) error {
g, err := m.getOrCreateGuest(ctx)
if err != nil {
return err
}
defer m.pool.Put(g)

s := requestStateFromContext(ctx)
defer s.Close()

return g.handleResponse(ctx, reqCtx, hostErr)
return s.g.handleResponse(ctx, reqCtx, hostErr)
}

// Close implements api.Closer
Expand Down Expand Up @@ -529,7 +524,7 @@ func (m *middleware) readBody(ctx context.Context, mod wazeroapi.Module, stack [
panic("unsupported body kind: " + strconv.Itoa(int(kind)))
}

eofLen := readBody(ctx, mod, buf, bufLimit, r)
eofLen := readBody(mod, buf, bufLimit, r)

stack[0] = eofLen
}
Expand Down Expand Up @@ -562,10 +557,10 @@ func (m *middleware) writeBody(ctx context.Context, mod wazeroapi.Module, params
panic("unsupported body kind: " + strconv.Itoa(int(kind)))
}

writeBody(ctx, mod, buf, bufLen, w)
writeBody(mod, buf, bufLen, w)
}

func writeBody(ctx context.Context, mod wazeroapi.Module, buf, bufLen uint32, w io.Writer) {
func writeBody(mod wazeroapi.Module, buf, bufLen uint32, w io.Writer) {
// buf_len 0 means to overwrite with nothing
var b []byte
if bufLen > 0 {
Expand Down Expand Up @@ -596,7 +591,7 @@ func (m *middleware) setStatusCode(ctx context.Context, params []uint64) {
m.host.SetStatusCode(ctx, statusCode)
}

func readBody(ctx context.Context, mod wazeroapi.Module, buf uint32, bufLimit handler.BufLimit, r io.Reader) (eofLen uint64) {
func readBody(mod wazeroapi.Module, buf uint32, bufLimit handler.BufLimit, r io.Reader) (eofLen uint64) {
// buf_limit 0 serves no purpose as implementations won't return EOF on it.
if bufLimit == 0 {
panic(fmt.Errorf("buf_limit==0 reading body"))
Expand Down Expand Up @@ -645,8 +640,8 @@ func mustBeforeNextOrFeature(ctx context.Context, feature handler.Features, op,

const i32, i64 = wazeroapi.ValueTypeI32, wazeroapi.ValueTypeI64

func (m *middleware) compileHost(ctx context.Context) (wazero.CompiledModule, error) {
if compiled, err := m.runtime.NewHostModuleBuilder(handler.HostModule).
func (m *middleware) instantiateHost(ctx context.Context) (wazeroapi.Module, error) {
return m.runtime.NewHostModuleBuilder(handler.HostModule).
NewFunctionBuilder().
WithGoFunction(wazeroapi.GoFunc(m.enableFeatures), []wazeroapi.ValueType{i32}, []wazeroapi.ValueType{i32}).
WithParameterNames("features").Export(handler.FuncEnableFeatures).
Expand Down Expand Up @@ -701,11 +696,7 @@ func (m *middleware) compileHost(ctx context.Context) (wazero.CompiledModule, er
NewFunctionBuilder().
WithGoFunction(wazeroapi.GoFunc(m.setStatusCode), []wazeroapi.ValueType{i32}, []wazeroapi.ValueType{}).
WithParameterNames("status_code").Export(handler.FuncSetStatusCode).
Compile(ctx); err != nil {
return nil, fmt.Errorf("wasm: error compiling host: %w", err)
} else {
return compiled, nil
}
Instantiate(ctx)
}

func mustHeaderMutable(ctx context.Context, op string, kind handler.HeaderKind) {
Expand Down Expand Up @@ -766,3 +757,23 @@ func writeStringIfUnderLimit(mem wazeroapi.Memory, offset, limit handler.BufLimi
mem.WriteString(offset, v)
return
}

type imports uint

const (
importWasiP1 imports = 1 << iota
importHttpHandler
)

func detectImports(importedFns []wazeroapi.FunctionDefinition) (imports imports) {
for _, f := range importedFns {
moduleName, _, _ := f.Import()
switch moduleName {
case handler.HostModule:
imports |= importHttpHandler
case wasi_snapshot_preview1.ModuleName:
imports |= importWasiP1
}
}
return
}
93 changes: 93 additions & 0 deletions handler/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package handler

import (
"context"
_ "embed"
"reflect"
"testing"

"github.com/http-wasm/http-wasm-host-go/api/handler"
"github.com/http-wasm/http-wasm-host-go/internal/test"
)

var testCtx = context.Background()

func Test_MiddlewareResponseUsesRequestModule(t *testing.T) {
mw, err := NewMiddleware(testCtx, test.BinE2EHandleResponse, handler.UnimplementedHost{})
if err != nil {
t.Fatal(err)
}
defer mw.Close(testCtx)

// A new guest module has initial state, so its value should be 42
r1Ctx, ctxNext, err := mw.HandleRequest(testCtx)
expectHandleRequest(t, mw, ctxNext, err, 42)

// The first guest shouldn't return to the pool until HandleResponse, so
// the second simultaneous call will get a new guest.
r2Ctx, ctxNext2, err := mw.HandleRequest(testCtx)
expectHandleRequest(t, mw, ctxNext2, err, 42)

// Return the first request to the pool
if err = mw.HandleResponse(r1Ctx, uint32(ctxNext>>32), nil); err != nil {
t.Fatal(err)
}
expectGlobals(t, mw, 43)

// The next request should re-use the returned module
r3Ctx, ctxNext3, err := mw.HandleRequest(testCtx)
expectHandleRequest(t, mw, ctxNext3, err, 43)
if err = mw.HandleResponse(r3Ctx, uint32(ctxNext3>>32), nil); err != nil {
t.Fatal(err)
}
expectGlobals(t, mw, 44)

// Return the second request to the pool
if err = mw.HandleResponse(r2Ctx, uint32(ctxNext2>>32), nil); err != nil {
t.Fatal(err)
}
expectGlobals(t, mw, 44, 43)
}

func expectGlobals(t *testing.T, mw Middleware, wantGlobals ...uint64) {
t.Helper()
if want, have := wantGlobals, getGlobalVals(mw); !reflect.DeepEqual(want, have) {
t.Errorf("unexpected globals, want: %v, have: %v", want, have)
}
}

func getGlobalVals(mw Middleware) []uint64 {
pool := mw.(*middleware).pool
var guests []*guest
var globals []uint64

// Take all guests out of the pool
for {
if g, ok := pool.Get().(*guest); ok {
guests = append(guests, g)
continue
}
break
}

for _, g := range guests {
v := g.guest.ExportedGlobal("reqCtx").Get()
globals = append(globals, v)
pool.Put(g)
}

return globals
}

func expectHandleRequest(t *testing.T, mw Middleware, ctxNext handler.CtxNext, err error, expectedCtx handler.CtxNext) {
t.Helper()
if err != nil {
t.Fatal(err)
}
if want, have := expectedCtx, ctxNext>>32; want != have {
t.Errorf("unexpected ctx, want: %d, have: %d", want, have)
}
if mw.(*middleware).pool.Get() != nil {
t.Error("expected handler to not return guest to the pool")
}
}
7 changes: 7 additions & 0 deletions handler/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ type requestState struct {
// features are the current request's features which may be more than
// Middleware.Features.
features handler.Features

putPool func(x any)
g *guest
}

func (r *requestState) closeRequest() (err error) {
Expand All @@ -44,6 +47,10 @@ func (r *requestState) closeRequest() (err error) {

// Close implements io.Closer
func (r *requestState) Close() (err error) {
if g := r.g; g != nil {
r.putPool(r.g)
r.g = nil
}
err = r.closeRequest()
if respBW := r.responseBodyWriter; respBW != nil {
if f, ok := respBW.(http.Flusher); ok {
Expand Down
7 changes: 6 additions & 1 deletion internal/test/testdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"
"os"
"path"
"runtime"
)

//go:embed testdata/bench/log.wasm
Expand Down Expand Up @@ -84,7 +85,11 @@ var BinE2EHeaderNames []byte

// binExample instead of go:embed as files aren't relative to this directory.
func binExample(name string) []byte {
p := path.Join("..", "..", "examples", name+".wasm")
_, thisFile, _, ok := runtime.Caller(1)
if !ok {
log.Panicln("cannot determine current path")
}
p := path.Join(path.Dir(thisFile), "..", "..", "examples", name+".wasm")
if wasm, err := os.ReadFile(p); err != nil {
log.Panicln(err)
return nil
Expand Down
Binary file modified internal/test/testdata/e2e/handle_response.wasm
Binary file not shown.
Loading

0 comments on commit 88a400f

Please sign in to comment.