Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [Go] added middleware framework for actions and models #429

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type ModelMetadata struct {

// DefineModel registers the given generate function as an action, and returns a
// [ModelAction] that runs it.
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *ModelAction {
func DefineModel(provider, name string, metadata *ModelMetadata, middleware []core.Middleware[*GenerateRequest, *GenerateResponse], generate func(context.Context, *GenerateRequest, ModelStreamingCallback) (*GenerateResponse, error)) *ModelAction {
metadataMap := map[string]any{}
if metadata != nil {
if metadata.Label != "" {
Expand All @@ -66,7 +66,7 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
}
return core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate)
}, middleware, generate)
}

// LookupModel looks up a [ModelAction] registered by [DefineModel].
Expand Down
54 changes: 46 additions & 8 deletions go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,38 @@ import (
"github.com/invopop/jsonschema"
)

// Middleware is a function that takes in an action handler function and
// returns a new handler function that might be changing input/output in
// some way.
//
// Middleware functions can:
// - execute arbitrary code;
// - change the request and response;
// - terminate response by returning a response (or error);
// - call the next middleware function.
type Middleware[I, O any] func(MiddlewareHandler[I, O]) MiddlewareHandler[I, O]

type MiddlewareHandler[I, O any] func(ctx context.Context, input I) (O, error)
pavelgj marked this conversation as resolved.
Show resolved Hide resolved

// Middlewares returns an array of middlewares that are passes in as an argument.
// core.Middlewares(apple, banana) is identical to []core.Middleware[InputType, OutputType]{apple, banana}
ianlancetaylor marked this conversation as resolved.
Show resolved Hide resolved
func Middlewares[I, O any](ms ...Middleware[I, O]) []Middleware[I, O] {
return ms
}

// ChainMiddleware creates a new Middleware that applies a sequence of
// Middlewares, so that they execute in the given order when handling action
// request.
// In other words, ChainMiddleware(m1, m2)(handler) = m1(m2(handler))
func ChainMiddleware[I, O any](middlewares ...Middleware[I, O]) Middleware[I, O] {
pavelgj marked this conversation as resolved.
Show resolved Hide resolved
return func(h MiddlewareHandler[I, O]) MiddlewareHandler[I, O] {
for i := range middlewares {
h = middlewares[len(middlewares)-1-i](h)
}
return h
}
}

// Func is the type of function that Actions and Flows execute.
// It takes an input of type Int and returns an output of type Out, optionally
// streaming values of type Stream incrementally by invoking a callback.
Expand Down Expand Up @@ -63,6 +95,7 @@ type Action[In, Out, Stream any] struct {
// optional
description string
metadata map[string]any
middleware []Middleware[In, Out]
}

// See js/core/src/action.ts
Expand All @@ -78,18 +111,18 @@ func defineAction[In, Out any](r *registry, provider, name string, atype atype.A
return a
}

func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, provider, name, atype, metadata, fn)
func DefineStreamingAction[In, Out, Stream any](provider, name string, atype atype.ActionType, metadata map[string]any, middleware []Middleware[In, Out], fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return defineStreamingAction(globalRegistry, provider, name, atype, metadata, middleware, fn)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still unclear about why the middlewares are getting passed all the way down. I can just apply them to the action function at the top level. I don't think anything in action.go needs to change.

}

func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := newStreamingAction(name, atype, metadata, fn)
func defineStreamingAction[In, Out, Stream any](r *registry, provider, name string, atype atype.ActionType, metadata map[string]any, middleware []Middleware[In, Out], fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
a := newStreamingAction(name, atype, metadata, middleware, fn)
r.registerAction(provider, a)
return a
}

func DefineCustomAction[In, Out, Stream any](provider, name string, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
return DefineStreamingAction(provider, name, atype.Custom, metadata, fn)
return DefineStreamingAction(provider, name, atype.Custom, metadata, nil, fn)
}

// DefineActionWithInputSchema creates a new Action and registers it.
Expand All @@ -108,13 +141,13 @@ func defineActionWithInputSchema[Out any](r *registry, provider, name string, at

// newAction creates a new Action with the given name and non-streaming function.
func newAction[In, Out any](name string, atype atype.ActionType, metadata map[string]any, fn func(context.Context, In) (Out, error)) *Action[In, Out, struct{}] {
return newStreamingAction(name, atype, metadata, func(ctx context.Context, in In, cb NoStream) (Out, error) {
return newStreamingAction(name, atype, metadata, nil, func(ctx context.Context, in In, cb NoStream) (Out, error) {
return fn(ctx, in)
})
}

// newStreamingAction creates a new Action with the given name and streaming function.
func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType, metadata map[string]any, middleware []Middleware[In, Out], fn Func[In, Out, Stream]) *Action[In, Out, Stream] {
var i In
var o Out
return &Action[In, Out, Stream]{
Expand All @@ -127,6 +160,7 @@ func newStreamingAction[In, Out, Stream any](name string, atype atype.ActionType
inputSchema: inferJSONSchema(i),
outputSchema: inferJSONSchema(o),
metadata: metadata,
middleware: middleware,
}
}

Expand Down Expand Up @@ -169,6 +203,7 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con
// This action has probably not been registered.
tstate = globalRegistry.tstate
}

return tracing.RunInNewSpan(ctx, tstate, a.name, "action", false, input,
func(ctx context.Context, input In) (Out, error) {
start := time.Now()
Expand All @@ -178,7 +213,10 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con
}
var output Out
if err == nil {
output, err = a.fn(ctx, input, cb)
dispatch := ChainMiddleware(a.middleware...)
output, err = dispatch(func(ctx context.Context, di In) (Out, error) {
return a.fn(ctx, di, cb)
})(ctx, input)
if err == nil {
if err = validateValue(output, a.outputSchema); err != nil {
err = fmt.Errorf("invalid output: %w", err)
Expand Down
58 changes: 57 additions & 1 deletion go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ func inc(_ context.Context, x int) (int, error) {
return x + 1, nil
}

func wrapRequest(next MiddlewareHandler[string, string]) MiddlewareHandler[string, string] {
return func(ctx context.Context, request string) (string, error) {
return next(ctx, "("+request+")")
}
}

func wrapResponse(next MiddlewareHandler[string, string]) MiddlewareHandler[string, string] {
return func(ctx context.Context, request string) (string, error) {
nextResponse, err := next(ctx, request)
if err != nil {
return "", err
}
return "[" + nextResponse + "]", nil
}
}

func TestActionRun(t *testing.T) {
a := newAction("inc", atype.Custom, nil, inc)
got, err := a.Run(context.Background(), 3, nil)
Expand Down Expand Up @@ -70,7 +86,7 @@ func count(ctx context.Context, n int, cb func(context.Context, int) error) (int

func TestActionStreaming(t *testing.T) {
ctx := context.Background()
a := newStreamingAction("count", atype.Custom, nil, count)
a := newStreamingAction("count", atype.Custom, nil, nil, count)
const n = 3

// Non-streaming.
Expand Down Expand Up @@ -126,3 +142,43 @@ func TestActionTracing(t *testing.T) {
}
t.Fatalf("did not find trace named %q", actionName)
}

func TestActionMiddleware(t *testing.T) {
ctx := context.Background()

sayHello := newStreamingAction("hello", atype.Custom, nil, Middlewares(wrapRequest, wrapResponse), func(ctx context.Context, input string, _ NoStream) (string, error) {
return "Hello " + input, nil
})

got, err := sayHello.Run(ctx, "Pavel", nil)
if err != nil {
t.Fatal(err)
}
want := "[Hello (Pavel)]"
if got != want {
t.Fatalf("got %v, want %v", got, want)
}
}

func TestActionInterruptedMiddleware(t *testing.T) {
ctx := context.Background()

interrupt := func(next MiddlewareHandler[string, string]) MiddlewareHandler[string, string] {
return func(ctx context.Context, request string) (string, error) {
return "interrupt (request: \"" + request + "\")", nil
}
}

a := newStreamingAction("hello", atype.Custom, nil, Middlewares(wrapRequest, interrupt, wrapResponse), func(ctx context.Context, input string, _ NoStream) (string, error) {
return "Hello " + input, nil
})

got, err := a.Run(ctx, "Pavel", nil)
if err != nil {
t.Fatal(err)
}
want := "interrupt (request: \"(Pavel)\")"
if got != want {
t.Fatalf("got %v, want %v", got, want)
}
}
2 changes: 1 addition & 1 deletion go/core/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func (f *Flow[In, Out, Stream]) action() *Action[*flowInstruction[In], *flowStat
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
return f.runInstruction(ctx, inst, streamingCallback[Stream](cb))
}
return newStreamingAction(f.name, atype.Flow, metadata, cback)
return newStreamingAction(f.name, atype.Flow, metadata, nil, cback)
}

// runInstruction performs one of several actions on a flow, as determined by msg.
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/dotprompt/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func testGenerate(ctx context.Context, req *ai.GenerateRequest, cb func(context.
}

func TestExecute(t *testing.T) {
testModel := ai.DefineModel("test", "test", nil, testGenerate)
testModel := ai.DefineModel("test", "test", nil, nil, testGenerate)
p, err := New("TestExecute", "TestExecute", Config{ModelAction: testModel})
if err != nil {
t.Fatal(err)
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func defineModel(name string, client *genai.Client) {
},
}
g := generator{model: name, client: client}
ai.DefineModel(provider, name, meta, g.generate)
ai.DefineModel(provider, name, meta, nil, g.generate)
}

func defineEmbedder(name string, client *genai.Client) {
Expand Down
2 changes: 1 addition & 1 deletion go/plugins/vertexai/vertexai.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func initModels(ctx context.Context, cfg Config) error {
},
}
g := &generator{model: name, client: gclient}
ai.DefineModel(provider, name, meta, g.generate)
ai.DefineModel(provider, name, meta, nil, g.generate)
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion js/flow/src/flow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ export function startFlowsServer(params?: {
flows.forEach((f) => {
const flowPath = `/${pathPrefix}${f.name}`;
logger.info(` - ${flowPath}`);
// Add middlware
// Add middleware
f.middleware?.forEach((m) => {
app.post(flowPath, m);
});
Expand Down
Loading