Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions agent/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright 2026 Redpanda Data, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package agent

import "context"

type contextKey string

const (
globalInstructionsKey contextKey = "gen_ai.global_instructions"
)

// ContextWithGlobalInstructions returns a new context with the provided global instructions.
//
// Global instructions are system-wide directives that apply to all agents in an
// invocation tree. When an agent is called (directly or as a tool), these
// instructions are appended to its base system prompt.
//
// Use this to propagate cross-cutting constraints like "Always respond in JSON"
// or "Be extremely concise" across hierarchical agent calls.
func ContextWithGlobalInstructions(ctx context.Context, instructions string) context.Context {
return context.WithValue(ctx, globalInstructionsKey, instructions)
}

// GlobalInstructions retrieves the global instructions from the context.
// Returns an empty string if no instructions are set.
func GlobalInstructions(ctx context.Context) string {
val, ok := ctx.Value(globalInstructionsKey).(string)
if !ok {
return ""
}
return val
}
83 changes: 83 additions & 0 deletions agent/llmagent/agenttool_propagation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2026 Redpanda Data, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package llmagent_test

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/redpanda-data/ai-sdk-go/agent"
"github.com/redpanda-data/ai-sdk-go/agent/llmagent"
"github.com/redpanda-data/ai-sdk-go/llm"
"github.com/redpanda-data/ai-sdk-go/llm/fakellm"
"github.com/redpanda-data/ai-sdk-go/store/session"
"github.com/redpanda-data/ai-sdk-go/tool"
"github.com/redpanda-data/ai-sdk-go/tool/agenttool"
)

func TestGlobalInstructionPropagation(t *testing.T) {
t.Parallel()

ctx := context.Background()

// 1. Setup Child Agent
childFake := fakellm.NewFakeModel()
childFake.When(fakellm.Any()).ThenRespondText("Child response")

childAgent, err := llmagent.New("child", "I am the child.", childFake)
require.NoError(t, err)

// 2. Setup Parent Agent with Child as a Tool
registry := tool.NewRegistry(tool.RegistryConfig{})
require.NoError(t, registry.Register(agenttool.New(childAgent)))

parentFake := fakellm.NewFakeModel()
// Rule to make the parent call the child tool
parentFake.When(fakellm.UserMessageContains("delegate")).
ThenRespondWithToolCall("child", map[string]any{"task": "do something"})
// Rule for the second turn (after tool result)
parentFake.When(fakellm.Any()).ThenRespondText("Parent done")

parentAgent, err := llmagent.New("parent", "I am the parent.", parentFake, llmagent.WithTools(registry))
require.NoError(t, err)

// 3. Execute with Global Instructions in Context
gctx := agent.ContextWithGlobalInstructions(ctx, "CRITICAL: ALWAYS USE JSON.")

sess := &session.State{ID: "parent-sess"}
sess.Messages = append(sess.Messages, llm.NewMessage(llm.RoleUser, llm.NewTextPart("delegate to child")))
inv := agent.NewInvocationMetadata(sess, parentAgent.Info())

for evt, err := range parentAgent.Run(gctx, inv) {
require.NoError(t, err)
_ = evt
}

// 4. Verify Parent Prompt
parentCalls := parentFake.Calls()
require.NotEmpty(t, parentCalls)
parentSystemMsg := findSystemMessage(parentCalls[0].Request)
assert.Contains(t, parentSystemMsg.TextContent(), "CRITICAL: ALWAYS USE JSON.")

// 5. Verify Child Prompt (Propagation check)
childCalls := childFake.Calls()
require.NotEmpty(t, childCalls, "Child agent should have been called as a tool")
childSystemMsg := findSystemMessage(childCalls[0].Request)
assert.Contains(t, childSystemMsg.TextContent(), "CRITICAL: ALWAYS USE JSON.",
"Child agent should have inherited global instructions from parent context")
}
55 changes: 34 additions & 21 deletions agent/llmagent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,37 @@ import (
"github.com/redpanda-data/ai-sdk-go/tool"
)

// SystemPromptProvider is a function that returns the system prompt for a
// InstructionProvider is a function that returns the system prompt for a
// given request. It is called once per LLM call (i.e., every turn in the
// agentic loop), receiving both the request context and the invocation
// metadata so callers can draw from either source:
//
// - ctx carries request-scoped values (e.g., authenticated identity
// injected by HTTP middleware via [context.WithValue]).
// - inv exposes session metadata, per-invocation metadata set by
// interceptors, and the current turn number.
// - inv.Session().Metadata contains long-lived session state, which is the
// preferred source for persistent templating variables (e.g., user name,
// preferences).
// - inv.Metadata() contains transient metadata primarily used for
// interceptor communication during the current invocation.
//
// Use [WithSystemPromptProvider] to configure it. When set, it takes
// Use [WithInstructionProvider] to configure it. When set, it takes
// precedence over the static systemPrompt string.
type SystemPromptProvider func(ctx context.Context, inv *agent.InvocationMetadata) (string, error)
type InstructionProvider func(ctx context.Context, inv *agent.InvocationMetadata) (string, error)

// config holds the internal configuration for an LLMAgent.
type config struct {
name string
description string
systemPrompt string
systemPromptProvider SystemPromptProvider
id string
version string
model llm.Model
tools tool.Registry
interceptors []agent.Interceptor
maxTurns int
toolConcurrency int
name string
description string
systemPrompt string
instructionProvider InstructionProvider
globalInstruction string
id string
version string
model llm.Model
tools tool.Registry
interceptors []agent.Interceptor
maxTurns int
toolConcurrency int
}

// validate checks that the configuration is valid.
Expand All @@ -59,8 +63,8 @@ func (c *config) validate() error {
return errors.New("llmagent: name is required")
}

if c.systemPrompt == "" && c.systemPromptProvider == nil {
return errors.New("llmagent: system prompt is required (set either systemPrompt or SystemPromptProvider)")
if c.systemPrompt == "" && c.instructionProvider == nil {
return errors.New("llmagent: system prompt is required (set either systemPrompt or InstructionProvider)")
}

if c.model == nil {
Expand Down Expand Up @@ -88,7 +92,7 @@ func (c *config) validate() error {
// Option configures an LLMAgent.
type Option func(*config)

// WithSystemPromptProvider sets a dynamic system prompt provider.
// WithInstructionProvider sets a dynamic system prompt provider.
//
// When set, the provider is called every turn to produce the system prompt,
// and the static systemPrompt argument to [New] is ignored. Pass an empty
Expand All @@ -97,9 +101,9 @@ type Option func(*config)
// The provider receives both context.Context (for request-scoped values like
// authenticated identity) and [agent.InvocationMetadata] (for session state,
// interceptor metadata, and turn number).
func WithSystemPromptProvider(p SystemPromptProvider) Option {
func WithInstructionProvider(p InstructionProvider) Option {
return func(c *config) {
c.systemPromptProvider = p
c.instructionProvider = p
}
}

Expand Down Expand Up @@ -149,6 +153,15 @@ func WithID(id string) Option {
}
}

// WithGlobalInstruction sets a static global instruction that applies to
// all agents in a multi-agent tree. It is appended to the system prompt
// along with any instructions found in the context.
func WithGlobalInstruction(instr string) Option {
return func(c *config) {
c.globalInstruction = instr
}
}

// WithVersion sets the agent's version (used for gen_ai.agent.version).
func WithVersion(version string) Option {
return func(c *config) {
Expand Down
149 changes: 149 additions & 0 deletions agent/llmagent/dynamic_prompt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright 2026 Redpanda Data, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package llmagent_test

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/redpanda-data/ai-sdk-go/agent"
"github.com/redpanda-data/ai-sdk-go/agent/llmagent"
"github.com/redpanda-data/ai-sdk-go/llm"
"github.com/redpanda-data/ai-sdk-go/llm/fakellm"
"github.com/redpanda-data/ai-sdk-go/store/session"
)

func TestDynamicInstructionProvider(t *testing.T) {
t.Parallel()

ctx := context.Background()

// Setup fake model to capture the request
fake := fakellm.NewFakeModel()
fake.When(fakellm.Any()).ThenRespondText("Done")

// Create agent with InstructionProvider
a, err := llmagent.New("test-agent", "Fallback prompt", fake,
llmagent.WithInstructionProvider(func(ctx context.Context, inv *agent.InvocationMetadata) (string, error) {
user := "Unknown"
if v, ok := inv.Metadata()["user_name"]; ok {
user = v.(string)
} else if sess := inv.Session(); sess != nil && sess.Metadata != nil {
if v, ok := sess.Metadata["user_name"]; ok {
user = v.(string)
}
}

role := "Assistant"
if v, ok := inv.Metadata()["role"]; ok {
role = v.(string)
}

date := time.Now().UTC().Format("2006-01-02")
return fmt.Sprintf("Hello %s! Today is %s. Your role is %s.", user, date, role), nil
}),
)
require.NoError(t, err)

// Case 1: Session Metadata
t.Run("Session Metadata", func(t *testing.T) {
sess := &session.State{
ID: "test-sess",
Metadata: map[string]any{
"user_name": "Alice",
},
}
sess.Messages = append(sess.Messages, llm.NewMessage(llm.RoleUser, llm.NewTextPart("hi")))
inv := agent.NewInvocationMetadata(sess, a.Info())

for evt, err := range a.Run(ctx, inv) {
require.NoError(t, err)
_ = evt
}

reqs := fake.Calls()
require.Len(t, reqs, 1)

systemMsg := findSystemMessage(reqs[0].Request)
require.NotNil(t, systemMsg)

expectedDate := time.Now().UTC().Format("2006-01-02")
assert.Contains(t, systemMsg.TextContent(), "Hello Alice!")
assert.Contains(t, systemMsg.TextContent(), "Today is "+expectedDate)
assert.Contains(t, systemMsg.TextContent(), "Your role is Assistant")
})

// Case 2: Invocation Metadata Overrides Session
t.Run("Invocation Override", func(t *testing.T) {
fake.ResetCalls()
sess := &session.State{
ID: "test-sess",
Metadata: map[string]any{
"user_name": "Alice",
},
}
sess.Messages = append(sess.Messages, llm.NewMessage(llm.RoleUser, llm.NewTextPart("hi")))
inv := agent.NewInvocationMetadata(sess, a.Info())
inv.SetMetadata("user_name", "Bob")
inv.SetMetadata("role", "Expert")

for evt, err := range a.Run(ctx, inv) {
require.NoError(t, err)
_ = evt
}

reqs := fake.Calls()
systemMsg := findSystemMessage(reqs[0].Request)
assert.Contains(t, systemMsg.TextContent(), "Hello Bob!")
assert.Contains(t, systemMsg.TextContent(), "Your role is Expert")
})

// Case 3: Global Instructions from Context
t.Run("Global Instructions", func(t *testing.T) {
fake.ResetCalls()
sess := &session.State{ID: "test-sess"}
sess.Messages = append(sess.Messages, llm.NewMessage(llm.RoleUser, llm.NewTextPart("hi")))
inv := agent.NewInvocationMetadata(sess, a.Info())

// Add global instructions to context
gctx := agent.ContextWithGlobalInstructions(ctx, "Be extremely polite.")

for evt, err := range a.Run(gctx, inv) {
require.NoError(t, err)
_ = evt
}

reqs := fake.Calls()
systemMsg := findSystemMessage(reqs[0].Request)

assert.Contains(t, systemMsg.TextContent(), "---")
assert.Contains(t, systemMsg.TextContent(), "## Global Instructions")
assert.Contains(t, systemMsg.TextContent(), "Be extremely polite.")
})
}

func findSystemMessage(req *llm.Request) *llm.Message {
for _, m := range req.Messages {
if m.Role == llm.RoleSystem {
return &m
}
}
return nil
}
Loading