From e3beb4ded2e146f8807be6b0d0de38870841d380 Mon Sep 17 00:00:00 2001 From: Alex Acebo Date: Tue, 7 May 2024 13:31:55 -0400 Subject: [PATCH] [C#] feat: change `PredictedSayCommand` to include response `ChatMessage` (#1610) ## Linked issues closes: #minor ## Details update `PredictedSayCommand` to have a `ChatMessage` response instead of a `string` type so that we can access citations and other message data from action handlers. --- .../AITests/AITests.cs | 2 +- .../AITests/AssistantsPlannerTests.cs | 16 ++++----- .../MonologueAugmentationTests.cs | 16 +++++++-- .../SequenceAugmentationTests.cs | 16 +++++++-- .../AI/Action/DefaultActions.cs | 4 +-- .../AI/Augmentations/MonologueAugmentation.cs | 8 ++++- .../AI/Augmentations/SequenceAugmentation.cs | 18 ++++++++++ .../AI/Models/MessageContext.cs | 4 +-- .../Moderator/AzureContentSafetyModerator.cs | 2 +- .../AI/Moderator/OpenAIModerator.cs | 2 +- .../AI/Planners/PredictedSayCommand.cs | 15 +++++++-- .../ChatMessageJsonConverter.cs | 33 +++++++++++++++++++ .../JsonConverters/CommandJsonConverter.cs | 2 +- 13 files changed, 115 insertions(+), 23 deletions(-) create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/ChatMessageJsonConverter.cs diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AITests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AITests.cs index dc8e1bdcb..77483a81a 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AITests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AITests.cs @@ -197,7 +197,7 @@ public string DoCommand([ActionName] string action) [Action(AIConstants.SayCommandActionName)] public string SayCommand([ActionParameters] PredictedSayCommand command) { - SayActionRecord.Add(command.Response); + SayActionRecord.Add(command.Response.Content); return string.Empty; } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AssistantsPlannerTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AssistantsPlannerTests.cs index 69f5074bf..f006bd1cb 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AssistantsPlannerTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/AssistantsPlannerTests.cs @@ -42,7 +42,7 @@ public async Task Test_BeginTaskAsync_Assistant_Single_Reply() Assert.NotNull(plan.Commands); Assert.Single(plan.Commands); Assert.Equal(AIConstants.SayCommand, plan.Commands[0].Type); - Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response); + Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response.Content); } [Fact] @@ -74,7 +74,7 @@ public async Task Test_BeginTaskAsync_Assistant_WaitForCurrentRun() Assert.NotNull(plan.Commands); Assert.Single(plan.Commands); Assert.Equal(AIConstants.SayCommand, plan.Commands[0].Type); - Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response); + Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response.Content); } [Fact] @@ -110,7 +110,7 @@ public async Task Test_BeginTaskAsync_Assistant_WaitForPreviousRun() Assert.NotNull(plan.Commands); Assert.Single(plan.Commands); Assert.Equal(AIConstants.SayCommand, plan.Commands[0].Type); - Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response); + Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response.Content); } [Fact] @@ -243,7 +243,7 @@ public async Task Test_ContinueTaskAsync_Assistant_RequiresAction() Assert.NotNull(plan2.Commands); Assert.Single(plan2.Commands); Assert.Equal(AIConstants.SayCommand, plan2.Commands[0].Type); - Assert.Equal("welcome", ((PredictedSayCommand)plan2.Commands[0]).Response); + Assert.Equal("welcome", ((PredictedSayCommand)plan2.Commands[0]).Response.Content); Assert.Single(turnState.SubmitToolMap); Assert.Equal("test-action", turnState.SubmitToolMap.First().Key); Assert.Equal("test-tool-id", turnState.SubmitToolMap.First().Value); @@ -291,7 +291,7 @@ public async Task Test_ContinueTaskAsync_Assistant_IgnoreRedundantAction() Assert.NotNull(plan2.Commands); Assert.Single(plan2.Commands); Assert.Equal(AIConstants.SayCommand, plan2.Commands[0].Type); - Assert.Equal("welcome", ((PredictedSayCommand)plan2.Commands[0]).Response); + Assert.Equal("welcome", ((PredictedSayCommand)plan2.Commands[0]).Response.Content); Assert.Single(turnState.SubmitToolMap); Assert.Equal("test-action", turnState.SubmitToolMap.First().Key); Assert.Equal("test-tool-id", turnState.SubmitToolMap.First().Value); @@ -328,9 +328,9 @@ public async Task Test_ContinueTaskAsync_Assistant_MultipleMessages() Assert.NotNull(plan.Commands); Assert.Equal(3, plan.Commands.Count); Assert.Equal(AIConstants.SayCommand, plan.Commands[0].Type); - Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response); - Assert.Equal("message 1", ((PredictedSayCommand)plan.Commands[1]).Response); - Assert.Equal("message 2", ((PredictedSayCommand)plan.Commands[2]).Response); + Assert.Equal("welcome", ((PredictedSayCommand)plan.Commands[0]).Response.Content); + Assert.Equal("message 1", ((PredictedSayCommand)plan.Commands[1]).Response.Content); + Assert.Equal("message 2", ((PredictedSayCommand)plan.Commands[2]).Response.Content); } private static async Task _CreateAssistantsState() diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/MonologueAugmentationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/MonologueAugmentationTests.cs index 939aafe6c..390b510be 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/MonologueAugmentationTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/MonologueAugmentationTests.cs @@ -254,7 +254,15 @@ public async void Test_CreatePlanFromResponseAsync_SayCommand_ShouldSucceed() Status = PromptResponseStatus.Success, Message = new(ChatRole.Assistant) { - Content = JsonSerializer.Serialize(monologue) + Content = JsonSerializer.Serialize(monologue), + Context = new() + { + Intent = "test intent", + Citations = new List + { + new("content", "title", "url") + } + } } }; @@ -267,7 +275,11 @@ public async void Test_CreatePlanFromResponseAsync_SayCommand_ShouldSucceed() Assert.NotNull(plan); Assert.Equal(1, plan.Commands.Count); Assert.Equal("SAY", plan.Commands[0].Type); - Assert.Equal("hello world", (plan.Commands[0] as PredictedSayCommand)?.Response); + Assert.Equal("hello world", (plan.Commands[0] as PredictedSayCommand)?.Response.Content); + Assert.Equal("test intent", (plan.Commands[0] as PredictedSayCommand)?.Response.Context?.Intent); + Assert.Equal("content", (plan.Commands[0] as PredictedSayCommand)?.Response.Context?.Citations[0].Content); + Assert.Equal("title", (plan.Commands[0] as PredictedSayCommand)?.Response.Context?.Citations[0].Title); + Assert.Equal("url", (plan.Commands[0] as PredictedSayCommand)?.Response.Context?.Citations[0].Url); } [Fact] diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/SequenceAugmentationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/SequenceAugmentationTests.cs index 41eff6b16..7ff3eef23 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/SequenceAugmentationTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Augmentations/SequenceAugmentationTests.cs @@ -49,7 +49,15 @@ public async Task Test_CreatePlanFromResponseAsync_ValidPlan_ShouldSucceed() ""response"": ""hello"" } ] -}" +}", + Context = new() + { + Intent = "test intent", + Citations = new List + { + new("content", "title", "url") + } + } } }; @@ -62,7 +70,11 @@ public async Task Test_CreatePlanFromResponseAsync_ValidPlan_ShouldSucceed() Assert.Equal("DO", plan.Commands[0].Type); Assert.Equal("test", (plan.Commands[0] as PredictedDoCommand)?.Action); Assert.Equal("SAY", plan.Commands[1].Type); - Assert.Equal("hello", (plan.Commands[1] as PredictedSayCommand)?.Response); + Assert.Equal("hello", (plan.Commands[1] as PredictedSayCommand)?.Response.Content); + Assert.Equal("test intent", (plan.Commands[1] as PredictedSayCommand)?.Response.Context?.Intent); + Assert.Equal("content", (plan.Commands[1] as PredictedSayCommand)?.Response.Context?.Citations[0].Content); + Assert.Equal("title", (plan.Commands[1] as PredictedSayCommand)?.Response.Context?.Citations[0].Title); + Assert.Equal("url", (plan.Commands[1] as PredictedSayCommand)?.Response.Context?.Citations[0].Url); } [Fact] diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Action/DefaultActions.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Action/DefaultActions.cs index 6f8eca717..3386231e4 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Action/DefaultActions.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Action/DefaultActions.cs @@ -81,11 +81,11 @@ public async Task SayCommandAsync([ActionTurnContext] ITurnContext turnC if (turnContext.Activity.ChannelId == Channels.Msteams) { - await turnContext.SendActivityAsync(command.Response.Replace("\n", "
"), null, null, cancellationToken); + await turnContext.SendActivityAsync(command.Response.Content.Replace("\n", "
"), null, null, cancellationToken); } else { - await turnContext.SendActivityAsync(command.Response, null, null, cancellationToken); + await turnContext.SendActivityAsync(command.Response.Content, null, null, cancellationToken); }; return string.Empty; diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/MonologueAugmentation.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/MonologueAugmentation.cs index 1d149313b..a4fd54a58 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/MonologueAugmentation.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/MonologueAugmentation.cs @@ -260,7 +260,13 @@ public MonologueAugmentation(List actions) } } - command = new PredictedSayCommand(text); + ChatMessage message = response.Message ?? new ChatMessage(ChatRole.Assistant) + { + Context = response.Message?.Context, + }; + + message.Content = text; + command = new PredictedSayCommand(message); } else { diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/SequenceAugmentation.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/SequenceAugmentation.cs index db9e8a45a..b20b395d8 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/SequenceAugmentation.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Augmentations/SequenceAugmentation.cs @@ -48,6 +48,24 @@ public SequenceAugmentation(List actions) try { Plan? plan = JsonSerializer.Deserialize(response.Message?.Content ?? ""); + + if (plan != null) + { + foreach (IPredictedCommand cmd in plan.Commands) + { + if (cmd is PredictedSayCommand say) + { + ChatMessage message = response.Message ?? new ChatMessage(ChatRole.Assistant) + { + Context = response.Message?.Context, + }; + + message.Content = say.Response.Content; + say.Response = message; + } + } + } + return await Task.FromResult(plan); } catch (Exception) diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs index 45f76d2c9..33c83c771 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs @@ -10,7 +10,7 @@ public class MessageContext /// /// Citations used in the message. /// - public IList Citations { get; } = new List(); + public IList Citations { get; set; } = new List(); /// /// The intent of the message. @@ -31,7 +31,7 @@ public class Citation /// /// The title of the citation. /// - public string Title { get; set; } + public string Title { get; set; } /// /// The URL of the citation. diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/AzureContentSafetyModerator.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/AzureContentSafetyModerator.cs index 1e956023a..ff65d70c1 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/AzureContentSafetyModerator.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/AzureContentSafetyModerator.cs @@ -56,7 +56,7 @@ public async Task ReviewOutputAsync(ITurnContext turnContext, TState turnS { if (command is PredictedSayCommand sayCommand) { - string output = sayCommand.Response; + string output = sayCommand.Response.Content; // If plan is flagged it will be replaced Plan? newPlan = await _HandleTextModeration(output, false); diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/OpenAIModerator.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/OpenAIModerator.cs index 4595198e9..663c9e044 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/OpenAIModerator.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Moderator/OpenAIModerator.cs @@ -65,7 +65,7 @@ public async Task ReviewOutputAsync(ITurnContext turnContext, TState turnS { if (command is PredictedSayCommand sayCommand) { - string output = sayCommand.Response; + string output = sayCommand.Response.Content; // If plan is flagged it will be replaced Plan? newPlan = await _HandleTextModerationAsync(output, false, cancellationToken); diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Planners/PredictedSayCommand.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Planners/PredictedSayCommand.cs index d1b039254..889663509 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Planners/PredictedSayCommand.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Planners/PredictedSayCommand.cs @@ -1,5 +1,7 @@ using Json.Schema; using System.Text.Json.Serialization; +using Microsoft.Teams.AI.AI.Models; +using Microsoft.Teams.AI.Utilities.JsonConverters; namespace Microsoft.Teams.AI.AI.Planners { @@ -17,19 +19,28 @@ public class PredictedSayCommand : IPredictedCommand /// The response that the AI system should say. /// [JsonPropertyName("response")] + [JsonConverter(typeof(ChatMessageJsonConverter))] [JsonRequired] - public string Response { get; set; } + public ChatMessage Response { get; set; } /// /// Creates a new instance of the class. /// /// The response that the AI system should say. [JsonConstructor] - public PredictedSayCommand(string response) + public PredictedSayCommand(ChatMessage response) { Response = response; } + public PredictedSayCommand(string response) + { + Response = new ChatMessage(ChatRole.Assistant) + { + Content = response + }; + } + /// /// Schema /// diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/ChatMessageJsonConverter.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/ChatMessageJsonConverter.cs new file mode 100644 index 000000000..1797c2928 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/ChatMessageJsonConverter.cs @@ -0,0 +1,33 @@ +using Microsoft.Teams.AI.AI.Models; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Teams.AI.Utilities.JsonConverters +{ + internal class ChatMessageJsonConverter : JsonConverter + { + public override ChatMessage Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + { + string? response = JsonSerializer.Deserialize(ref reader); + + if (response == null) + { + throw new JsonException(); + } + + return new ChatMessage(ChatRole.Assistant) + { + Content = response + }; + } + + public override void Write(Utf8JsonWriter writer, ChatMessage value, JsonSerializerOptions options) + { + writer.WriteStringValue(value.Content); + writer.Flush(); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/CommandJsonConverter.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/CommandJsonConverter.cs index cb5724127..54544df16 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/CommandJsonConverter.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Utilities/JsonConverters/CommandJsonConverter.cs @@ -83,7 +83,7 @@ public override void Write(Utf8JsonWriter writer, IPredictedCommand value, JsonS writer.WritePropertyName(_responsePropertyName); - JsonSerializer.Serialize(writer, ((PredictedSayCommand)value).Response, options); + JsonSerializer.Serialize(writer, ((PredictedSayCommand)value).Response.Content, options); } else {