Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ModelContextProtocol.Core/AIContentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ public static class AIContentExtensions
{
if (sm.Content?.Select(b => b.ToAIContent()).OfType<AIContent>().ToList() is { Count: > 0 } aiContents)
{
messages.Add(new ChatMessage(sm.Role is Role.Assistant ? ChatRole.Assistant : ChatRole.User, aiContents));
ChatRole role = aiContents.All(c => c is FunctionResultContent) ? ChatRole.Tool :
sm.Role is Role.Assistant ? ChatRole.Assistant :
ChatRole.User;
messages.Add(new ChatMessage(role, aiContents));
}
}

Expand Down
78 changes: 78 additions & 0 deletions tests/ModelContextProtocol.Tests/Client/McpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,84 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages()
Assert.Equal("endTurn", result.StopReason);
}

[Fact]
public async Task CreateSamplingHandler_ShouldUseToolRoleForToolResultMessages()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var requestParams = new CreateMessageRequestParams
{
Messages =
[
new SamplingMessage
{
Role = Role.User,
Content = [new TextContentBlock { Text = "What is the weather in Paris?" }]
},
new SamplingMessage
{
Role = Role.Assistant,
Content = [new ToolUseContentBlock
{
Id = "call_weather_123",
Name = "get_weather",
Input = JsonElement.Parse("""{"location":"Paris"}""")
}]
},
new SamplingMessage
{
Role = Role.User,
Content = [new ToolResultContentBlock
{
ToolUseId = "call_weather_123",
Content = [new TextContentBlock { Text = "Weather: 18°C, sunny" }]
}]
}
],
MaxTokens = 100
};

IEnumerable<ChatMessage>? capturedMessages = null;
var cancellationToken = CancellationToken.None;
var expectedResponse = new[] {
new ChatResponseUpdate
{
ModelId = "test-model",
FinishReason = ChatFinishReason.Stop,
Role = ChatRole.Assistant,
Contents = [new TextContent("The weather in Paris is 18°C and sunny.")]
}
}.ToAsyncEnumerable();

mockChatClient
.Setup(client => client.GetStreamingResponseAsync(It.IsAny<IEnumerable<ChatMessage>>(), It.IsAny<ChatOptions>(), cancellationToken))
.Callback<IEnumerable<ChatMessage>, ChatOptions?, CancellationToken>((messages, _, _) => capturedMessages = messages.ToList())
.Returns(expectedResponse);

var handler = mockChatClient.Object.CreateSamplingHandler();

// Act
var result = await handler(requestParams, Mock.Of<IProgress<ProgressNotificationValue>>(), cancellationToken);

// Assert
Assert.NotNull(result);
Assert.NotNull(capturedMessages);
var messagesList = capturedMessages.ToList();
Assert.Equal(3, messagesList.Count);

// First message should be User role (text message)
Assert.Equal(ChatRole.User, messagesList[0].Role);
Assert.IsType<TextContent>(messagesList[0].Contents.Single());

// Second message should be Assistant role (tool use)
Assert.Equal(ChatRole.Assistant, messagesList[1].Role);
Assert.IsType<FunctionCallContent>(messagesList[1].Contents.Single());

// Third message should be Tool role (tool result) - this is the bug fix
Assert.Equal(ChatRole.Tool, messagesList[2].Role);
Assert.IsType<FunctionResultContent>(messagesList[2].Contents.Single());
}

[Fact]
public async Task ListToolsAsync_AllToolsReturned()
{
Expand Down
Loading