Skip to content
6 changes: 3 additions & 3 deletions dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,9 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider
inputMessagesForChatClient.AddRange(await typedThread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false));
}

// Add the input messages before getting context from AIContextProvider.
inputMessagesForChatClient.AddRange(inputMessages);

// If we have an AIContextProvider, we should get context from it, and update our
// messages and options with the additional context.
if (typedThread.AIContextProvider is not null)
Expand Down Expand Up @@ -675,9 +678,6 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider
chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}";
}
}

// Add the input messages to the end of thread messages.
inputMessagesForChatClient.AddRange(inputMessages);
}

// If a user provided two different thread ids, via the thread object and options, we should throw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,13 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync()
await agent.RunAsync(requestMessages, thread);

// Assert
// Should contain: base instructions, context message, user message, base function, context function
// Should contain: base instructions, user message, context message, base function, context function
Assert.Equal(2, capturedMessages.Count);
Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions);
Assert.Equal("context provider message", capturedMessages[0].Text);
Assert.Equal(ChatRole.System, capturedMessages[0].Role);
Assert.Equal("user message", capturedMessages[1].Text);
Assert.Equal(ChatRole.User, capturedMessages[1].Role);
Assert.Equal("user message", capturedMessages[0].Text);
Assert.Equal(ChatRole.User, capturedMessages[0].Role);
Assert.Equal("context provider message", capturedMessages[1].Text);
Assert.Equal(ChatRole.System, capturedMessages[1].Role);
Assert.Equal(2, capturedTools.Count);
Assert.Contains(capturedTools, t => t.Name == "base function");
Assert.Contains(capturedTools, t => t.Name == "context provider function");
Expand Down Expand Up @@ -2056,13 +2056,13 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync()
_ = await updates.ToAgentRunResponseAsync();

// Assert
// Should contain: base instructions, context message, user message, base function, context function
// Should contain: base instructions, user message, context message, base function, context function
Assert.Equal(2, capturedMessages.Count);
Assert.Equal("base instructions\ncontext provider instructions", capturedInstructions);
Assert.Equal("context provider message", capturedMessages[0].Text);
Assert.Equal(ChatRole.System, capturedMessages[0].Role);
Assert.Equal("user message", capturedMessages[1].Text);
Assert.Equal(ChatRole.User, capturedMessages[1].Role);
Assert.Equal("user message", capturedMessages[0].Text);
Assert.Equal(ChatRole.User, capturedMessages[0].Role);
Assert.Equal("context provider message", capturedMessages[1].Text);
Assert.Equal(ChatRole.System, capturedMessages[1].Role);
Assert.Equal(2, capturedTools.Count);
Assert.Contains(capturedTools, t => t.Name == "base function");
Assert.Contains(capturedTools, t => t.Name == "context provider function");
Expand Down Expand Up @@ -2129,6 +2129,152 @@ await Assert.ThrowsAsync<InvalidOperationException>(async () =>
x.InvokeException is InvalidOperationException), It.IsAny<CancellationToken>()), Times.Once);
}

/// <summary>
/// Verify that messages are stored in MessageStore with AIContextProvider messages.
/// Order stored should be: Input messages, AIContextProvider messages, Response messages.
/// </summary>
[Fact]
public async Task VerifyMessageOrderingWithAIContextProviderAsync()
{
// Arrange
var existingMessages = new List<ChatMessage>
{
new(ChatRole.User, "Message A"),
new(ChatRole.Assistant, "Message B")
};

var inputMessage = new ChatMessage(ChatRole.User, "Message C");
var aiContextProviderMessage = new ChatMessage(ChatRole.System, "Message X");
var responseMessage = new ChatMessage(ChatRole.Assistant, "Message D");

List<ChatMessage>? messagesToChatClient = null;
Mock<IChatClient> mockService = new();
mockService
.Setup(s => s.GetResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Callback<IEnumerable<ChatMessage>, ChatOptions, CancellationToken>((msgs, opts, ct) => messagesToChatClient = msgs.ToList())
.ReturnsAsync(new ChatResponse([responseMessage]));

var mockContextProvider = new Mock<AIContextProvider>();
mockContextProvider
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new AIContext
{
Messages = [aiContextProviderMessage],
});

var messageStore = new InMemoryChatMessageStore();
await messageStore.AddMessagesAsync(existingMessages);

ChatClientAgent agent = new(mockService.Object, options: new()
{
ChatOptions = new() { Instructions = "test instructions" },
AIContextProviderFactory = _ => mockContextProvider.Object
});

var thread = new ChatClientAgentThread
{
MessageStore = messageStore,
AIContextProvider = mockContextProvider.Object
};

// Act
await agent.RunAsync([inputMessage], thread);

// Assert - Verify order sent to chat client: [Existing, Input, AIContextProvider]
Assert.NotNull(messagesToChatClient);
Assert.Equal(4, messagesToChatClient.Count);
Assert.Equal("Message A", messagesToChatClient[0].Text);
Assert.Equal("Message B", messagesToChatClient[1].Text);
Assert.Equal("Message C", messagesToChatClient[2].Text);
Assert.Equal("Message X", messagesToChatClient[3].Text);

// Assert - Verify order stored in MessageStore: [Existing, Input, AIContextProvider, Response]
var storedMessagesList = (await messageStore.GetMessagesAsync()).ToList();
Assert.Equal(5, storedMessagesList.Count);
Assert.Equal("Message A", storedMessagesList[0].Text);
Assert.Equal("Message B", storedMessagesList[1].Text);
Assert.Equal("Message C", storedMessagesList[2].Text);
Assert.Equal("Message X", storedMessagesList[3].Text);
Assert.Equal("Message D", storedMessagesList[4].Text);
}

/// <summary>
/// Verify that messages are stored in MessageStore with AIContextProvider messages (streaming version).
/// Order stored should be: Input messages, AIContextProvider messages, Response messages.
/// </summary>
[Fact]
public async Task VerifyMessageOrderingWithAIContextProviderStreamingAsync()
{
// Arrange
var existingMessages = new List<ChatMessage>
{
new(ChatRole.User, "Message A"),
new(ChatRole.Assistant, "Message B")
};

var inputMessage = new ChatMessage(ChatRole.User, "Message C");
var aiContextProviderMessage = new ChatMessage(ChatRole.System, "Message X");
ChatResponseUpdate responseUpdate = new() { Role = ChatRole.Assistant };
responseUpdate.Contents.Add(new TextContent("Message D"));

List<ChatMessage>? messagesToChatClient = null;
Mock<IChatClient> mockService = new();
mockService
.Setup(s => s.GetStreamingResponseAsync(
It.IsAny<IEnumerable<ChatMessage>>(),
It.IsAny<ChatOptions>(),
It.IsAny<CancellationToken>()))
.Callback<IEnumerable<ChatMessage>, ChatOptions, CancellationToken>((msgs, opts, ct) => messagesToChatClient = msgs.ToList())
.Returns(ToAsyncEnumerableAsync([responseUpdate]));

var mockContextProvider = new Mock<AIContextProvider>();
mockContextProvider
.Setup(p => p.InvokingAsync(It.IsAny<AIContextProvider.InvokingContext>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(new AIContext
{
Messages = [aiContextProviderMessage],
});

var messageStore = new InMemoryChatMessageStore();
await messageStore.AddMessagesAsync(existingMessages);

ChatClientAgent agent = new(mockService.Object, options: new()
{
ChatOptions = new() { Instructions = "test instructions" },
AIContextProviderFactory = _ => mockContextProvider.Object
});

var thread = new ChatClientAgentThread
{
MessageStore = messageStore,
AIContextProvider = mockContextProvider.Object
};

// Act
var updates = agent.RunStreamingAsync([inputMessage], thread);
await updates.ToAgentRunResponseAsync();

// Assert - Verify order sent to chat client: [Existing, Input, AIContextProvider]
Assert.NotNull(messagesToChatClient);
Assert.Equal(4, messagesToChatClient.Count);
Assert.Equal("Message A", messagesToChatClient[0].Text);
Assert.Equal("Message B", messagesToChatClient[1].Text);
Assert.Equal("Message C", messagesToChatClient[2].Text);
Assert.Equal("Message X", messagesToChatClient[3].Text);

// Assert - Verify order stored in MessageStore: [Existing, Input, AIContextProvider, Response]
var storedMessagesList = (await messageStore.GetMessagesAsync()).ToList();
Assert.Equal(5, storedMessagesList.Count);
Assert.Equal("Message A", storedMessagesList[0].Text);
Assert.Equal("Message B", storedMessagesList[1].Text);
Assert.Equal("Message C", storedMessagesList[2].Text);
Assert.Equal("Message X", storedMessagesList[3].Text);
Assert.Equal("Message D", storedMessagesList[4].Text);
}

#endregion

private static async IAsyncEnumerable<T> ToAsyncEnumerableAsync<T>(IEnumerable<T> values)
Expand Down