diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index a5a34d24a9..b0634d4316 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -647,6 +647,9 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider inputMessagesForChatClient.AddRange(await typedThread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false)); } + // Add the input messages before getting context from AIContextProvider. + inputMessagesForChatClient.AddRange(inputMessages); + // If we have an AIContextProvider, we should get context from it, and update our // messages and options with the additional context. if (typedThread.AIContextProvider is not null) @@ -675,9 +678,6 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}"; } } - - // Add the input messages to the end of thread messages. - inputMessagesForChatClient.AddRange(inputMessages); } // If a user provided two different thread ids, via the thread object and options, we should throw diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 6e9d952b57..14195eb69c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -599,13 +599,13 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() await agent.RunAsync(requestMessages, thread); // Assert - // Should contain: base instructions, context message, user message, base function, context function + // Should contain: base instructions, user message, context message, base function, context function Assert.Equal(2, capturedMessages.Count); Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions); - Assert.Equal("context provider message", capturedMessages[0].Text); - Assert.Equal(ChatRole.System, capturedMessages[0].Role); - Assert.Equal("user message", capturedMessages[1].Text); - Assert.Equal(ChatRole.User, capturedMessages[1].Role); + Assert.Equal("user message", capturedMessages[0].Text); + Assert.Equal(ChatRole.User, capturedMessages[0].Role); + Assert.Equal("context provider message", capturedMessages[1].Text); + Assert.Equal(ChatRole.System, capturedMessages[1].Role); Assert.Equal(2, capturedTools.Count); Assert.Contains(capturedTools, t => t.Name == "base function"); Assert.Contains(capturedTools, t => t.Name == "context provider function"); @@ -2056,13 +2056,13 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() _ = await updates.ToAgentRunResponseAsync(); // Assert - // Should contain: base instructions, context message, user message, base function, context function + // Should contain: base instructions, user message, context message, base function, context function Assert.Equal(2, capturedMessages.Count); Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions); - Assert.Equal("context provider message", capturedMessages[0].Text); - Assert.Equal(ChatRole.System, capturedMessages[0].Role); - Assert.Equal("user message", capturedMessages[1].Text); - Assert.Equal(ChatRole.User, capturedMessages[1].Role); + Assert.Equal("user message", capturedMessages[0].Text); + Assert.Equal(ChatRole.User, capturedMessages[0].Role); + Assert.Equal("context provider message", capturedMessages[1].Text); + Assert.Equal(ChatRole.System, capturedMessages[1].Role); Assert.Equal(2, capturedTools.Count); Assert.Contains(capturedTools, t => t.Name == "base function"); Assert.Contains(capturedTools, t => t.Name == "context provider function");