Skip to content
6 changes: 3 additions & 3 deletions dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,9 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider
inputMessagesForChatClient.AddRange(await typedThread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false));
}

// Add the input messages before getting context from AIContextProvider.
inputMessagesForChatClient.AddRange(inputMessages);

// If we have an AIContextProvider, we should get context from it, and update our
// messages and options with the additional context.
if (typedThread.AIContextProvider is not null)
Expand Down Expand Up @@ -675,9 +678,6 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider
chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}";
}
}

// Add the input messages to the end of thread messages.
inputMessagesForChatClient.AddRange(inputMessages);
}

// If a user provided two different thread ids, via the thread object and options, we should throw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,13 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync()
await agent.RunAsync(requestMessages, thread);

// Assert
// Should contain: base instructions, context message, user message, base function, context function
// Should contain: base instructions, user message, context message, base function, context function
Assert.Equal(2, capturedMessages.Count);
Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions);
Assert.Equal("context provider message", capturedMessages[0].Text);
Assert.Equal(ChatRole.System, capturedMessages[0].Role);
Assert.Equal("user message", capturedMessages[1].Text);
Assert.Equal(ChatRole.User, capturedMessages[1].Role);
Assert.Equal("user message", capturedMessages[0].Text);
Assert.Equal(ChatRole.User, capturedMessages[0].Role);
Assert.Equal("context provider message", capturedMessages[1].Text);
Assert.Equal(ChatRole.System, capturedMessages[1].Role);
Assert.Equal(2, capturedTools.Count);
Assert.Contains(capturedTools, t => t.Name == "base function");
Assert.Contains(capturedTools, t => t.Name == "context provider function");
Expand Down Expand Up @@ -2056,13 +2056,13 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync()
_ = await updates.ToAgentRunResponseAsync();

// Assert
// Should contain: base instructions, context message, user message, base function, context function
// Should contain: base instructions, user message, context message, base function, context function
Assert.Equal(2, capturedMessages.Count);
Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions);
Assert.Equal("context provider message", capturedMessages[0].Text);
Assert.Equal(ChatRole.System, capturedMessages[0].Role);
Assert.Equal("user message", capturedMessages[1].Text);
Assert.Equal(ChatRole.User, capturedMessages[1].Role);
Assert.Equal("user message", capturedMessages[0].Text);
Assert.Equal(ChatRole.User, capturedMessages[0].Role);
Assert.Equal("context provider message", capturedMessages[1].Text);
Assert.Equal(ChatRole.System, capturedMessages[1].Role);
Assert.Equal(2, capturedTools.Count);
Assert.Contains(capturedTools, t => t.Name == "base function");
Assert.Contains(capturedTools, t => t.Name == "context provider function");
Expand Down
Loading