Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/ModelContextProtocol.Core/AIContentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ public static class AIContentExtensions
{
if (sm.Content?.Select(b => b.ToAIContent()).OfType<AIContent>().ToList() is { Count: > 0 } aiContents)
{
messages.Add(new ChatMessage(sm.Role is Role.Assistant ? ChatRole.Assistant : ChatRole.User, aiContents));
ChatRole role =
aiContents.All(static c => c is FunctionResultContent) ? ChatRole.Tool :
sm.Role is Role.Assistant ? ChatRole.Assistant :
ChatRole.User;
messages.Add(new ChatMessage(role, aiContents));
}
}

Expand Down
99 changes: 98 additions & 1 deletion tests/ModelContextProtocol.Tests/Client/McpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,103 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages()
Assert.Equal("endTurn", result.StopReason);
}

[Fact]
public async Task CreateSamplingHandler_ShouldUseToolRoleForToolResultMessages()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var requestParams = new CreateMessageRequestParams
{
Messages =
[
new SamplingMessage
{
Role = Role.User,
Content = [new TextContentBlock { Text = "What is the weather in Paris?" }]
},
new SamplingMessage
{
Role = Role.Assistant,
Content = [new ToolUseContentBlock
{
Id = "call_weather_123",
Name = "get_weather",
Input = JsonElement.Parse("""{"location":"Paris"}""")
}]
},
new SamplingMessage
{
Role = Role.User,
Content = [new ToolResultContentBlock
{
ToolUseId = "call_weather_123",
Content = [new TextContentBlock { Text = "Weather: 18°C, sunny" }]
}]
},
new SamplingMessage
{
Role = Role.User,
Content =
[
new ToolResultContentBlock
{
ToolUseId = "call_mixed_123",
Content = [new TextContentBlock { Text = "Tool result" }]
},
new TextContentBlock { Text = "Additional text content" }
]
}
],
MaxTokens = 100
};

IEnumerable<ChatMessage>? capturedMessages = null;
var cancellationToken = CancellationToken.None;
var expectedResponse = new[] {
new ChatResponseUpdate
{
ModelId = "test-model",
FinishReason = ChatFinishReason.Stop,
Role = ChatRole.Assistant,
Contents = [new TextContent("The weather in Paris is 18°C and sunny.")]
}
}.ToAsyncEnumerable();

mockChatClient
.Setup(client => client.GetStreamingResponseAsync(It.IsAny<IEnumerable<ChatMessage>>(), It.IsAny<ChatOptions>(), cancellationToken))
.Callback<IEnumerable<ChatMessage>, ChatOptions?, CancellationToken>((messages, _, _) => capturedMessages = messages.ToList())
.Returns(expectedResponse);

var handler = mockChatClient.Object.CreateSamplingHandler();

// Act
var result = await handler(requestParams, Mock.Of<IProgress<ProgressNotificationValue>>(), cancellationToken);

// Assert
Assert.NotNull(result);
Assert.NotNull(capturedMessages);
var messagesList = capturedMessages.ToList();
Assert.Equal(4, messagesList.Count);

// First message should be User role (text message)
Assert.Equal(ChatRole.User, messagesList[0].Role);
Assert.IsType<TextContent>(messagesList[0].Contents.Single());

// Second message should be Assistant role (tool use)
Assert.Equal(ChatRole.Assistant, messagesList[1].Role);
Assert.IsType<FunctionCallContent>(messagesList[1].Contents.Single());

// Third message should be Tool role (tool result only) - this is the bug fix
Assert.Equal(ChatRole.Tool, messagesList[2].Role);
Assert.IsType<FunctionResultContent>(messagesList[2].Contents.Single());

// Fourth message should be User role (mixed content: tool result + text)
Assert.Equal(ChatRole.User, messagesList[3].Role);
Assert.Equal(2, messagesList[3].Contents.Count);
Assert.Contains(messagesList[3].Contents, c => c is FunctionResultContent);
Assert.Contains(messagesList[3].Contents, c => c is TextContent);
}

[Fact]
public async Task ListToolsAsync_AllToolsReturned()
{
Expand Down Expand Up @@ -690,4 +787,4 @@ public async Task SetLoggingLevelAsync_WithRequestParams_NullThrows()
await Assert.ThrowsAsync<ArgumentNullException>("requestParams",
() => client.SetLoggingLevelAsync((SetLevelRequestParams)null!, TestContext.Current.CancellationToken));
}
}
}