From 2d49a75e7e0a1255f8833b86bd6cd1b5ba4faa47 Mon Sep 17 00:00:00 2001 From: kavin <115390646+singhk97@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:36:43 -0700 Subject: [PATCH] [C#] bump: Migrate to OpenAI's official .NET SDK (#1858) ## Linked issues closes: #1821 (issue number) ## Details * Migrated from `Azure.AI.OpenAI` to OpenAI's official .NET SDK `OpenAI`. #### Change details * **[BREAKING]** Removed `/completions` endpoint support for both OpenAI & AzureOpenAI. It was deprecated by OpenAI in June 2023 and virtually no one uses it. * [x] Updated all the samples to migrate to the `/chat` endpoint. * Updated default AzureOpenAI API version to `2024-06-01`. * The earliest supported API version is now `2024-04-01-preview`. This is a consequence of using OpenAI's .NET SDK `OpenAI`. * Migrated `OpenAIModel` & `OpenAIEmbeddings` classes' underlying clients. ## Attestation Checklist - [x] My code follows the style guidelines of this project - I have checked for/fixed spelling, linting, and other errors - I have commented my code for clarity - I have made corresponding changes to the documentation (updating the doc strings in the code is sufficient) - My changes generate no new warnings - I have added tests that validate my changes, and provide sufficient test coverage. 
I have tested with: - Local testing - E2E testing in Teams - New and existing unit tests pass locally with my changes ### Additional information > Feel free to add other relevant information below --- .../AITests/ChatMessageTests.cs | 223 +++++++++++++ ...ions.cs => ChatCompletionToolCallTests.cs} | 15 +- .../Models/ChatMessageExtensionsTests.cs | 161 +++++---- .../AITests/Models/OpenAIModelTests.cs | 311 +++--------------- ....cs => SequentialDelayRetryPolicyTests.cs} | 24 +- .../AITests/OpenAIEmbeddingsTests.cs | 86 +++-- .../IntegrationTests/OpenAIEmbeddingsTests.cs | 4 +- .../IntegrationTests/OpenAIModelTests.cs | 7 +- .../Microsoft.Teams.AI.Tests.csproj | 3 +- .../TestUtils/TestResponse.cs | 53 +++ .../netstandard2.0/CoverletSourceRootsMapping | Bin 390 -> 0 bytes .../AI/Embeddings/OpenAIEmbeddings.cs | 71 ++-- .../AI/Models/AddHeaderRequestPolicy.cs | 20 +- .../Models/AzureSdkChatMessageExtensions.cs | 67 ---- .../AI/Models/ChatCompletionToolCall.cs | 36 +- .../AI/Models/ChatMessage.cs | 171 +++++++++- .../AI/Models/ChatMessageExtensions.cs | 136 -------- .../AI/Models/MessageContext.cs | 25 +- .../AI/Models/OpenAIModel.cs | 211 +++++------- .../RequestFailedExceptionExtensions.cs | 30 -- .../AI/Models/SequentialDelayRetryPolicy.cs | 23 ++ .../AI/Models/SequentialDelayStrategy.cs | 24 -- .../Exceptions/HttpOperationException.cs | 21 +- .../Microsoft.Teams.AI.csproj | 3 +- .../01.messaging.echoBot/EchoBot.csproj | 2 +- .../SearchCommand.csproj | 2 +- .../TypeAheadBot.csproj | 4 +- .../04.ai.a.teamsChefBot/TeamsChefBot.csproj | 6 +- .../GPT.csproj | 2 +- .../LightBot.csproj | 2 +- .../ListBot.csproj | 2 +- .../DevOpsBot.csproj | 2 +- .../CardGazer.csproj | 2 +- .../TwentyQuestions.csproj | 2 +- .../06.assistants.a.mathBot/MathBot.csproj | 2 +- .../06.assistants.b.orderBot/OrderBot.csproj | 2 +- .../samples/06.auth.oauth.bot/BotAuth.csproj | 2 +- .../MessageExtensionAuth.csproj | 2 +- .../06.auth.teamsSSO.bot/BotAuth.csproj | 2 +- .../MessageExtensionAuth.csproj 
| 2 +- .../AzureAISearchBot/AzureAISearchBot.csproj | 6 +- .../AzureOpenAIBot.csproj | 6 +- 42 files changed, 913 insertions(+), 862 deletions(-) rename dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/{AzureSdkChatMessageExtensions.cs => ChatCompletionToolCallTests.cs} (63%) rename dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/{SequentialDelayStrategyTests.cs => SequentialDelayRetryPolicyTests.cs} (50%) create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/TestUtils/TestResponse.cs delete mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/bin/Debug/netstandard2.0/CoverletSourceRootsMapping delete mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AzureSdkChatMessageExtensions.cs delete mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessageExtensions.cs delete mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/RequestFailedExceptionExtensions.cs create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayRetryPolicy.cs delete mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayStrategy.cs diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/ChatMessageTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/ChatMessageTests.cs index d310837ae..c0a14397d 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/ChatMessageTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/ChatMessageTests.cs @@ -1,4 +1,8 @@ using Microsoft.Teams.AI.AI.Models; +using Microsoft.Teams.AI.Exceptions; +using OpenAI.Chat; +using System.ClientModel.Primitives; +using ChatMessage = Microsoft.Teams.AI.AI.Models.ChatMessage; namespace Microsoft.Teams.AI.Tests.AITests { @@ -28,5 +32,224 @@ public void Test_Get_Content_TypeMismatch_ThrowsException() // Act & 
Assert Assert.Throws(() => msg.GetContent()); } + + [Fact] + public void Test_Initialization_From_OpenAISdk_ChatMessage() + { + // Arrange + var chatCompletion = ModelReaderWriter.Read(BinaryData.FromString(@$"{{ + ""choices"": [ + {{ + ""finish_reason"": ""stop"", + ""message"": {{ + ""role"": ""assistant"", + ""content"": ""test-choice"", + ""context"": {{ + ""citations"": [ + {{ + ""title"": ""test-title"", + ""url"": ""test-url"", + ""content"": ""test-content"" + }} + ] + }} + }} + }} + ] + }}")); + + // Act + var message = new ChatMessage(chatCompletion!); + + // Assert + Assert.Equal("test-choice", message.Content); + Assert.Equal(ChatRole.Assistant, message.Role); + + var context = message.Context; + Assert.NotNull(context); + Assert.Equal(1, context.Citations.Count); + Assert.Equal("test-title", context.Citations[0].Title); + Assert.Equal("test-url", context.Citations[0].Url); + Assert.Equal("test-content", context.Citations[0].Content); + } + + [Fact] + public void Test_InvalidRole_ToOpenAISdkChatMessage() + { + // Arrange + var chatMessage = new ChatMessage(new ChatRole("InvalidRole")) + { + Content = "test" + }; + + // Act + var ex = Assert.Throws(() => chatMessage.ToOpenAIChatMessage()); + + // Assert + Assert.Equal($"Invalid chat message role: InvalidRole", ex.Message); + } + + [Fact] + public void Test_UserRole_StringContent_ToOpenAISdkChatMessage() + { + // Arrange + var chatMessage = new ChatMessage(ChatRole.User) + { + Content = "test-content", + Name = "author" + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var userMessage = result as UserChatMessage; + Assert.NotNull(userMessage); + Assert.Equal("test-content", result.Content[0].Text); + // TODO: Uncomment once participant name issue is resolved. 
+ //Assert.Equal("author", userMessage.ParticipantName); + } + + [Fact] + public void Test_UserRole_MultiModalContent_ToOpenAISdkChatMessage() + { + // Arrange + var messageContentParts = new List() { new TextContentPart() { Text = "test" }, new ImageContentPart { ImageUrl = "https://www.testurl.com" } }; + var chatMessage = new ChatMessage(ChatRole.User) + { + Content = messageContentParts, + Name = "author" + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var userMessage = result as UserChatMessage; + Assert.NotNull(userMessage); + Assert.Equal("test", userMessage.Content[0].Text); + Assert.Equal("https://www.testurl.com", userMessage.Content[1].ImageUri.OriginalString); + + // TODO: Uncomment once participant name issue is resolved. + //Assert.Equal("author", userMessage.ParticipantName); + } + + [Fact] + public void Test_AssistantRole_ToOpenAISdkChatMessage_FunctionCall() + { + // Arrange + var functionCall = new FunctionCall("test-name", "test-arg1"); + var chatMessage = new ChatMessage(ChatRole.Assistant) + { + Content = "test-content", + Name = "test-name", + FunctionCall = functionCall, + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var assistantMessage = result as AssistantChatMessage; + Assert.NotNull(assistantMessage); + Assert.Equal("test-content", assistantMessage.Content[0].Text); + // TODO: Uncomment when participant name issue is resolved. 
+ //Assert.Equal("test-name", assistantMessage.ParticipantName); + Assert.Equal("test-arg1", assistantMessage.FunctionCall.FunctionArguments); + Assert.Equal("test-name", assistantMessage.FunctionCall.FunctionName); + } + + [Fact] + public void Test_AssistantRole_ToOpenAISdkChatMessage_ToolCall() + { + // Arrange + var chatMessage = new ChatMessage(ChatRole.Assistant) + { + Content = "test-content", + Name = "test-name", + ToolCalls = new List() + { + new ChatCompletionsFunctionToolCall("test-id", "test-tool-name", "test-tool-arg1") + } + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var assistantMessage = result as AssistantChatMessage; + Assert.NotNull(assistantMessage); + Assert.Equal("test-content", assistantMessage.Content[0].Text); + // TODO: Uncomment when participant name issue is resolved. + //Assert.Equal("test-name", assistantMessage.ParticipantName); + + Assert.Equal(1, assistantMessage.ToolCalls.Count); + ChatToolCall toolCall = assistantMessage.ToolCalls[0]; + Assert.NotNull(toolCall); + Assert.Equal("test-id", toolCall.Id); + Assert.Equal("test-tool-name", toolCall.FunctionName); + Assert.Equal("test-tool-arg1", toolCall.FunctionArguments); + } + + [Fact] + public void Test_SystemRole_ToOpenAISdkChatMessage() + { + // Arrange + var chatMessage = new ChatMessage(ChatRole.System) + { + Content = "test-content", + Name = "author" + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var systemMessage = result as SystemChatMessage; + Assert.NotNull(systemMessage); + Assert.Equal("test-content", systemMessage.Content[0].Text); + // TODO: Uncomment when participant name issue is resolved. 
+ //Assert.Equal("author", systemMessage.ParticipantName); + } + + [Fact] + public void Test_FunctionRole_ToOpenAISdkChatMessage() + { + // Arrange + var chatMessage = new ChatMessage(ChatRole.Function) + { + Content = "test-content", + Name = "function-name" + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var functionMessage = result as FunctionChatMessage; + Assert.NotNull(functionMessage); + Assert.Equal("test-content", functionMessage.Content[0].Text); + } + + [Fact] + public void Test_ToolRole_ToOpenAISdkChatMessage() + { + // Arrange + var chatMessage = new ChatMessage(ChatRole.Tool) + { + Content = "test-content", + Name = "tool-name", + ToolCallId = "tool-call-id" + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var toolMessage = result as ToolChatMessage; + Assert.NotNull(toolMessage); + Assert.Equal("test-content", toolMessage.Content[0].Text); + Assert.Equal("tool-call-id", toolMessage.ToolCallId); + } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/AzureSdkChatMessageExtensions.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/ChatCompletionToolCallTests.cs similarity index 63% rename from dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/AzureSdkChatMessageExtensions.cs rename to dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/ChatCompletionToolCallTests.cs index c950d98de..449fd34a4 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/AzureSdkChatMessageExtensions.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/ChatCompletionToolCallTests.cs @@ -1,18 +1,19 @@ using Microsoft.Teams.AI.AI.Models; using Microsoft.Teams.AI.Exceptions; +using OpenAI.Chat; namespace Microsoft.Teams.AI.Tests.AITests.Models { - public class AzureSdkChatMessageExtensions + internal class ChatCompletionToolCallTests { [Fact] 
public void Test_ChatCompletionsToolCall_ToFunctionToolCall() { // Arrange - var functionToolCall = new Azure.AI.OpenAI.ChatCompletionsFunctionToolCall("test-id", "test-name", "test-arg1"); + var functionToolCall = ChatToolCall.CreateFunctionToolCall("test-id", "test-name", "test-arg1"); // Act - var azureSdkFunctionToolCall = functionToolCall.ToChatCompletionsToolCall(); + var azureSdkFunctionToolCall = ChatCompletionsToolCall.FromChatToolCall(functionToolCall); // Assert var toolCall = azureSdkFunctionToolCall as ChatCompletionsFunctionToolCall; @@ -29,15 +30,15 @@ public void Test_ChatCompletionsToolCall_InvalidToolType() var functionToolCall = new InvalidToolCall(); // Act - var ex = Assert.Throws(() => functionToolCall.ToChatCompletionsToolCall()); + var ex = Assert.Throws(() => functionToolCall.ToChatToolCall()); // Assert - Assert.Equal($"Invalid ChatCompletionsToolCall type: {nameof(InvalidToolCall)}", ex.Message); + Assert.Equal("Invalid tool type: invalidToolType", ex.Message); } - private sealed class InvalidToolCall : Azure.AI.OpenAI.ChatCompletionsToolCall + private sealed class InvalidToolCall : ChatCompletionsToolCall { - public InvalidToolCall() : base("test-id") + public InvalidToolCall() : base("invalidToolType", "test-id") { } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/ChatMessageExtensionsTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/ChatMessageExtensionsTests.cs index d42e60b7b..add2c5288 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/ChatMessageExtensionsTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/ChatMessageExtensionsTests.cs @@ -1,153 +1,175 @@ -using Azure.AI.OpenAI; -using Microsoft.Teams.AI.AI.Models; +using Microsoft.Teams.AI.AI.Models; using Microsoft.Teams.AI.Exceptions; +using OpenAI.Chat; +using ChatMessage = Microsoft.Teams.AI.AI.Models.ChatMessage; namespace 
Microsoft.Teams.AI.Tests.AITests.Models { public class ChatMessageExtensionsTests { [Fact] - public void Test_InvalidRole_ToAzureSdkChatMessage() + public void Test_InvalidRole_ToOpenAISdkChatMessage() { // Arrange - var chatMessage = new ChatMessage(new AI.Models.ChatRole("InvalidRole")) + var chatMessage = new ChatMessage(new ChatRole("InvalidRole")) { Content = "test" }; // Act - var ex = Assert.Throws(() => chatMessage.ToChatRequestMessage()); + var ex = Assert.Throws(() => chatMessage.ToOpenAIChatMessage()); // Assert Assert.Equal($"Invalid chat message role: InvalidRole", ex.Message); } [Fact] - public void Test_UserRole_StringContent_ToAzureSdkChatMessage() + public void Test_UserRole_StringContent_ToOpenAISdkChatMessage() { // Arrange - var chatMessage = new ChatMessage(AI.Models.ChatRole.User) + var chatMessage = new ChatMessage(ChatRole.User) { Content = "test-content", Name = "author" }; // Act - var result = chatMessage.ToChatRequestMessage(); + var result = chatMessage.ToOpenAIChatMessage(); // Assert - Assert.Equal(Azure.AI.OpenAI.ChatRole.User, result.Role); - Assert.Equal(typeof(ChatRequestUserMessage), result.GetType()); - Assert.Equal("test-content", ((ChatRequestUserMessage)result).Content); - Assert.Equal("author", ((ChatRequestUserMessage)result).Name); + var userMessage = result as UserChatMessage; + Assert.NotNull(userMessage); + Assert.Equal("test-content", result.Content[0].Text); + // TODO: Uncomment once participant name issue is resolved. 
+ //Assert.Equal("author", userMessage.ParticipantName); } [Fact] - public void Test_UserRole_MultiModalContent_ToAzureSdkChatMessage() + public void Test_UserRole_MultiModalContent_ToOpenAISdkChatMessage() { // Arrange var messageContentParts = new List() { new TextContentPart() { Text = "test" }, new ImageContentPart { ImageUrl = "https://www.testurl.com" } }; - var chatMessage = new ChatMessage(AI.Models.ChatRole.User) + var chatMessage = new ChatMessage(ChatRole.User) { Content = messageContentParts, Name = "author" }; // Act - var result = chatMessage.ToChatRequestMessage(); + var result = chatMessage.ToOpenAIChatMessage(); // Assert - Assert.Equal(Azure.AI.OpenAI.ChatRole.User, result.Role); - Assert.Equal(typeof(ChatRequestUserMessage), result.GetType()); + var userMessage = result as UserChatMessage; + Assert.NotNull(userMessage); + Assert.Equal("test", userMessage.Content[0].Text); + Assert.Equal("https://www.testurl.com", userMessage.Content[1].ImageUri.OriginalString); - var userMessage = (ChatRequestUserMessage)result; - - Assert.Equal(null, userMessage.Content); - Assert.Equal("test", ((ChatMessageTextContentItem)userMessage.MultimodalContentItems[0]).Text); - Assert.Equal(typeof(ChatMessageImageContentItem), userMessage.MultimodalContentItems[1].GetType()); - Assert.Equal("author", userMessage.Name); + // TODO: Uncomment once participant name issue is resolved. 
+ //Assert.Equal("author", userMessage.ParticipantName); } [Fact] - public void Test_AssistantRole_ToAzureSdkChatMessage() + public void Test_AssistantRole_ToOpenAISdkChatMessage_FunctionCall() { // Arrange - var functionCall = new AI.Models.FunctionCall("test-name", "test-arg1"); - var chatMessage = new ChatMessage(AI.Models.ChatRole.Assistant) + var functionCall = new FunctionCall("test-name", "test-arg1"); + var chatMessage = new ChatMessage(ChatRole.Assistant) { Content = "test-content", Name = "test-name", FunctionCall = functionCall, - ToolCalls = new List() + }; + + // Act + var result = chatMessage.ToOpenAIChatMessage(); + + // Assert + var assistantMessage = result as AssistantChatMessage; + Assert.NotNull(assistantMessage); + Assert.Equal("test-content", assistantMessage.Content[0].Text); + // TODO: Uncomment when participant name issue is resolved. + //Assert.Equal("test-name", assistantMessage.ParticipantName); + Assert.Equal("test-arg1", assistantMessage.FunctionCall.FunctionArguments); + Assert.Equal("test-name", assistantMessage.FunctionCall.FunctionName); + } + + [Fact] + public void Test_AssistantRole_ToOpenAISdkChatMessage_ToolCall() + { + // Arrange + var chatMessage = new ChatMessage(ChatRole.Assistant) + { + Content = "test-content", + Name = "test-name", + ToolCalls = new List() { - new AI.Models.ChatCompletionsFunctionToolCall("test-id", "test-tool-name", "test-tool-arg1") + new ChatCompletionsFunctionToolCall("test-id", "test-tool-name", "test-tool-arg1") } }; // Act - var result = chatMessage.ToChatRequestMessage(); + var result = chatMessage.ToOpenAIChatMessage(); // Assert - Assert.Equal(Azure.AI.OpenAI.ChatRole.Assistant, result.Role); - ChatRequestAssistantMessage? 
message = result as ChatRequestAssistantMessage; - Assert.NotNull(message); - Assert.Equal("test-content", message.Content); - Assert.Equal("test-name", message.Name); - Assert.Equal("test-arg1", message.FunctionCall.Arguments); - Assert.Equal("test-name", message.FunctionCall.Name); - - Assert.Equal(1, message.ToolCalls.Count); - Azure.AI.OpenAI.ChatCompletionsFunctionToolCall? toolCall = message.ToolCalls[0] as Azure.AI.OpenAI.ChatCompletionsFunctionToolCall; + var assistantMessage = result as AssistantChatMessage; + Assert.NotNull(assistantMessage); + Assert.Equal("test-content", assistantMessage.Content[0].Text); + // TODO: Uncomment when participant name issue is resolved. + //Assert.Equal("test-name", assistantMessage.ParticipantName); + + Assert.Equal(1, assistantMessage.ToolCalls.Count); + ChatToolCall toolCall = assistantMessage.ToolCalls[0]; Assert.NotNull(toolCall); Assert.Equal("test-id", toolCall.Id); - Assert.Equal("test-tool-name", toolCall.Name); - Assert.Equal("test-tool-arg1", toolCall.Arguments); + Assert.Equal("test-tool-name", toolCall.FunctionName); + Assert.Equal("test-tool-arg1", toolCall.FunctionArguments); } [Fact] - public void Test_SystemRole_ToAzureSdkChatMessage() + public void Test_SystemRole_ToOpenAISdkChatMessage() { // Arrange - var chatMessage = new ChatMessage(AI.Models.ChatRole.System) + var chatMessage = new ChatMessage(ChatRole.System) { Content = "test-content", Name = "author" }; // Act - var result = chatMessage.ToChatRequestMessage(); + var result = chatMessage.ToOpenAIChatMessage(); // Assert - Assert.Equal(Azure.AI.OpenAI.ChatRole.System, result.Role); - Assert.Equal(typeof(ChatRequestSystemMessage), result.GetType()); - Assert.Equal("test-content", ((ChatRequestSystemMessage)result).Content); - Assert.Equal("author", ((ChatRequestSystemMessage)result).Name); + var systemMessage = result as SystemChatMessage; + Assert.NotNull(systemMessage); + Assert.Equal("test-content", systemMessage.Content[0].Text); + // TODO: 
Uncomment when participant name issue is resolved. + //Assert.Equal("author", systemMessage.ParticipantName); } [Fact] - public void Test_FunctionRole_ToAzureSdkChatMessage() + public void Test_FunctionRole_ToOpenAISdkChatMessage() { // Arrange - var chatMessage = new ChatMessage(AI.Models.ChatRole.Function) + var chatMessage = new ChatMessage(ChatRole.Function) { Content = "test-content", Name = "function-name" }; // Act - var result = chatMessage.ToChatRequestMessage(); + var result = chatMessage.ToOpenAIChatMessage(); // Assert - Assert.Equal(Azure.AI.OpenAI.ChatRole.Function, result.Role); - Assert.Equal(typeof(ChatRequestFunctionMessage), result.GetType()); - Assert.Equal("test-content", ((ChatRequestFunctionMessage)result).Content); + var functionMessage = result as FunctionChatMessage; + Assert.NotNull(functionMessage); + Assert.Equal("test-content", functionMessage.Content[0].Text); } [Fact] - public void Test_ToolRole_ToAzureSdkChatMessage() + public void Test_ToolRole_ToOpenAISdkChatMessage() { // Arrange - var chatMessage = new ChatMessage(AI.Models.ChatRole.Tool) + var chatMessage = new ChatMessage(ChatRole.Tool) { Content = "test-content", Name = "tool-name", @@ -155,30 +177,29 @@ public void Test_ToolRole_ToAzureSdkChatMessage() }; // Act - var result = chatMessage.ToChatRequestMessage(); + var result = chatMessage.ToOpenAIChatMessage(); // Assert - Assert.Equal(Azure.AI.OpenAI.ChatRole.Tool, result.Role); - Assert.Equal(typeof(ChatRequestToolMessage), result.GetType()); - Assert.Equal("test-content", ((ChatRequestToolMessage)result).Content); - Assert.Equal("tool-call-id", ((ChatRequestToolMessage)result).ToolCallId); + var toolMessage = result as ToolChatMessage; + Assert.NotNull(toolMessage); + Assert.Equal("test-content", toolMessage.Content[0].Text); + Assert.Equal("tool-call-id", toolMessage.ToolCallId); } [Fact] public void Test_ChatCompletionsToolCall_ToFunctionToolCall() { // Arrange - var functionToolCall = new 
AI.Models.ChatCompletionsFunctionToolCall("test-id", "test-name", "test-arg1"); + var functionToolCall = new ChatCompletionsFunctionToolCall("test-id", "test-name", "test-arg1"); // Act - var azureSdkFunctionToolCall = functionToolCall.ToAzureSdkChatCompletionsToolCall(); + var chatToolCall = functionToolCall.ToChatToolCall(); // Assert - var toolCall = azureSdkFunctionToolCall as Azure.AI.OpenAI.ChatCompletionsFunctionToolCall; - Assert.NotNull(toolCall); - Assert.Equal("test-id", toolCall.Id); - Assert.Equal("test-name", toolCall.Name); - Assert.Equal("test-arg1", toolCall.Arguments); + Assert.NotNull(chatToolCall); + Assert.Equal("test-id", chatToolCall.Id); + Assert.Equal("test-name", chatToolCall.FunctionName); + Assert.Equal("test-arg1", chatToolCall.FunctionArguments); } [Fact] @@ -188,13 +209,13 @@ public void Test_ChatCompletionsToolCall_InvalidToolType() var functionToolCall = new InvalidToolCall(); // Act - var ex = Assert.Throws(() => functionToolCall.ToAzureSdkChatCompletionsToolCall()); + var ex = Assert.Throws(() => functionToolCall.ToChatToolCall()); // Assert Assert.Equal("Invalid tool type: invalidToolType", ex.Message); } - private sealed class InvalidToolCall : AI.Models.ChatCompletionsToolCall + private sealed class InvalidToolCall : ChatCompletionsToolCall { public InvalidToolCall() : base("invalidToolType", "test-id") { diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/OpenAIModelTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/OpenAIModelTests.cs index d608eee79..6b9d3ea15 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/OpenAIModelTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/OpenAIModelTests.cs @@ -1,7 +1,4 @@ -using Azure; -using Azure.AI.OpenAI; -using Azure.Core; -using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder; using Microsoft.Teams.AI.AI.Models; using 
Microsoft.Teams.AI.AI.Prompts; using Microsoft.Teams.AI.AI.Prompts.Sections; @@ -9,7 +6,11 @@ using Microsoft.Teams.AI.State; using Microsoft.Teams.AI.Tests.TestUtils; using Moq; -using System.Diagnostics.CodeAnalysis; +using OpenAI; +using OpenAI.Chat; +using OAIChatMessage = OpenAI.Chat.ChatMessage; +using System.ClientModel; +using System.ClientModel.Primitives; using System.Reflection; using ChatMessage = Microsoft.Teams.AI.AI.Models.ChatMessage; using ChatRole = Microsoft.Teams.AI.AI.Models.ChatRole; @@ -35,7 +36,7 @@ public void Test_Constructor_AzureOpenAI_InvalidAzureApiVersion() var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/"); var versions = new List { - "2022-12-01", "2023-05-15", "2023-06-01-preview", "2023-07-01-preview", "2024-02-15-preview", "2024-03-01-preview" + "2024-04-01-preview", "2024-05-01-preview", "2024-06-01" }; // Act @@ -51,137 +52,6 @@ public void Test_Constructor_AzureOpenAI_InvalidAzureApiVersion() Assert.Equal("Model created with an unsupported API version of `2023-12-01-preview`.", exception.Message); } - [Fact] - public async void Test_CompletePromptAsync_AzureOpenAI_Text_PromptTooLong() - { - // Arrange - var turnContextMock = new Mock(); - var turnStateMock = new Mock(); - var renderedPrompt = new RenderedPromptSection(string.Empty, length: 65536, tooLong: true); - var promptMock = new Mock(new List(), -1, true, "\n\n"); - promptMock.Setup((prompt) => prompt.RenderAsTextAsync( - It.IsAny(), It.IsAny(), It.IsAny>>(), - It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); - var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object); - var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") - { - CompletionType = CompletionConfiguration.CompletionType.Text, - LogRequests = true - }; - var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); - - // Act - var result = 
await openAIModel.CompletePromptAsync(turnContextMock.Object, turnStateMock.Object, new PromptManager(), new GPTTokenizer(), promptTemplate); - - // Assert - Assert.Equal(PromptResponseStatus.TooLong, result.Status); - Assert.NotNull(result.Error); - Assert.Equal("The generated text completion prompt had a length of 65536 tokens which exceeded the MaxInputTokens of 2048.", result.Error.Message); - } - - [Fact] - public async void Test_CompletePromptAsync_AzureOpenAI_Text_RateLimited() - { - // Arrange - var turnContextMock = new Mock(); - var turnStateMock = new Mock(); - var renderedPrompt = new RenderedPromptSection(string.Empty, length: 256, tooLong: false); - var promptMock = new Mock(new List(), -1, true, "\n\n"); - promptMock.Setup((prompt) => prompt.RenderAsTextAsync( - It.IsAny(), It.IsAny(), It.IsAny>>(), - It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); - var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object); - var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") - { - CompletionType = CompletionConfiguration.CompletionType.Text, - LogRequests = true - }; - var clientMock = new Mock(); - var response = new TestResponse(429, "exception"); - var exception = new RequestFailedException(response); - clientMock.Setup((client) => client.GetCompletionsAsync(It.IsAny(), It.IsAny())).ThrowsAsync(exception); - var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); - openAIModel.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAIModel, clientMock.Object); - - // Act - var result = await openAIModel.CompletePromptAsync(turnContextMock.Object, turnStateMock.Object, new PromptManager(), new GPTTokenizer(), promptTemplate); - - // Assert - Assert.Equal(PromptResponseStatus.RateLimited, result.Status); - Assert.NotNull(result.Error); - Assert.Equal("The text completion API returned a rate limit error.", 
result.Error.Message); - } - - [Fact] - public async void Test_CompletePromptAsync_AzureOpenAI_Text_RequestFailed() - { - // Arrange - var turnContextMock = new Mock(); - var turnStateMock = new Mock(); - var renderedPrompt = new RenderedPromptSection(string.Empty, length: 256, tooLong: false); - var promptMock = new Mock(new List(), -1, true, "\n\n"); - promptMock.Setup((prompt) => prompt.RenderAsTextAsync( - It.IsAny(), It.IsAny(), It.IsAny>>(), - It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); - var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object); - var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") - { - CompletionType = CompletionConfiguration.CompletionType.Text, - LogRequests = true, - }; - var clientMock = new Mock(); - var response = new TestResponse(500, "exception"); - var exception = new RequestFailedException(response); - clientMock.Setup((client) => client.GetCompletionsAsync(It.IsAny(), It.IsAny())).ThrowsAsync(exception); - var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); - openAIModel.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAIModel, clientMock.Object); - - // Act - var result = await openAIModel.CompletePromptAsync(turnContextMock.Object, turnStateMock.Object, new PromptManager(), new GPTTokenizer(), promptTemplate); - - // Assert - Assert.Equal(PromptResponseStatus.Error, result.Status); - Assert.NotNull(result.Error); - Assert.True(result.Error.Message.StartsWith("The text completion API returned an error status of InternalServerError: Service request failed.\r\nStatus: 500 (exception)")); - } - - [Fact] - public async void Test_CompletePromptAsync_AzureOpenAI_Text() - { - // Arrange - var turnContextMock = new Mock(); - var turnStateMock = new Mock(); - var renderedPrompt = new RenderedPromptSection(string.Empty, length: 256, tooLong: false); - var promptMock 
= new Mock(new List(), -1, true, "\n\n"); - promptMock.Setup((prompt) => prompt.RenderAsTextAsync( - It.IsAny(), It.IsAny(), It.IsAny>>(), - It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); - var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object); - var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") - { - CompletionType = CompletionConfiguration.CompletionType.Text, - LogRequests = true - }; - var clientMock = new Mock(); - var choice = CreateChoice("test-choice", 0, null, null, null, null); - var usage = CreateCompletionsUsage(0, 0, 0); - var completions = CreateCompletions("test-id", DateTimeOffset.UtcNow, new List { choice }, usage); - Response response = new TestResponse(200, string.Empty); - clientMock.Setup((client) => client.GetCompletionsAsync(It.IsAny(), It.IsAny())).ReturnsAsync(Response.FromValue(completions, response)); - var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); - openAIModel.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAIModel, clientMock.Object); - - // Act - var result = await openAIModel.CompletePromptAsync(turnContextMock.Object, turnStateMock.Object, new PromptManager(), new GPTTokenizer(), promptTemplate); - - // Assert - Assert.Equal(PromptResponseStatus.Success, result.Status); - Assert.NotNull(result.Message); - Assert.Null(result.Error); - Assert.Equal(ChatRole.Assistant, result.Message.Role); - Assert.Equal("test-choice", result.Message.Content); - } - [Fact] public async void Test_CompletePromptAsync_AzureOpenAI_Chat_PromptTooLong() { @@ -197,7 +67,7 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat_PromptTooLong() var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") { CompletionType = CompletionConfiguration.CompletionType.Chat, - LogRequests = true, + LogRequests = true }; var 
openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); @@ -219,8 +89,8 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat_RateLimited() var renderedPrompt = new RenderedPromptSection>(new List(), length: 256, tooLong: false); var promptMock = new Mock(new List(), -1, true, "\n\n"); promptMock.Setup((prompt) => prompt.RenderAsMessagesAsync( - It.IsAny(), It.IsAny(), It.IsAny>>(), - It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); + It.IsAny(), It.IsAny(), It.IsAny>>(), + It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object); var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") { @@ -229,8 +99,8 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat_RateLimited() }; var clientMock = new Mock(); var response = new TestResponse(429, "exception"); - var exception = new RequestFailedException(response); - clientMock.Setup((client) => client.GetChatCompletionsAsync(It.IsAny(), It.IsAny())).ThrowsAsync(exception); + var exception = new ClientResultException(response); + clientMock.Setup((client) => client.GetChatClient(It.IsAny()).CompleteChatAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ThrowsAsync(exception); var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); openAIModel.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAIModel, clientMock.Object); @@ -252,18 +122,18 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat_RequestFailed() var renderedPrompt = new RenderedPromptSection>(new List(), length: 256, tooLong: false); var promptMock = new Mock(new List(), -1, true, "\n\n"); promptMock.Setup((prompt) => prompt.RenderAsMessagesAsync( - It.IsAny(), It.IsAny(), It.IsAny>>(), - It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); + It.IsAny(), It.IsAny(), It.IsAny>>(), 
+ It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object); var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") { CompletionType = CompletionConfiguration.CompletionType.Chat, - LogRequests = true + LogRequests = true, }; var clientMock = new Mock(); var response = new TestResponse(500, "exception"); - var exception = new RequestFailedException(response); - clientMock.Setup((client) => client.GetChatCompletionsAsync(It.IsAny(), It.IsAny())).ThrowsAsync(exception); + var exception = new ClientResultException(response); + clientMock.Setup((client) => client.GetChatClient(It.IsAny()).CompleteChatAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ThrowsAsync(exception); var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); openAIModel.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAIModel, clientMock.Object); @@ -285,21 +155,33 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat() var renderedPrompt = new RenderedPromptSection>(new List(), length: 256, tooLong: false); var promptMock = new Mock(new List(), -1, true, "\n\n"); promptMock.Setup((prompt) => prompt.RenderAsMessagesAsync( - It.IsAny(), It.IsAny(), It.IsAny>>(), - It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); + It.IsAny(), It.IsAny(), It.IsAny>>(), + It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync(renderedPrompt); var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object); var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/") { CompletionType = CompletionConfiguration.CompletionType.Chat, - LogRequests = true + LogRequests = true, }; var clientMock = new Mock(); - var chatResponseMessage = CreateChatResponseMessage(Azure.AI.OpenAI.ChatRole.Assistant, "test-choice", null, null, null, null); - var 
chatChoice = CreateChatChoice(chatResponseMessage, null, 0, null, null, null, null, null, null); - var usage = CreateCompletionsUsage(0, 0, 0); - var chatCompletions = CreateChatCompletions("test-id", DateTimeOffset.UtcNow, new List { chatChoice }, usage); - Response response = new TestResponse(200, string.Empty); - clientMock.Setup((client) => client.GetChatCompletionsAsync(It.IsAny(), It.IsAny())).ReturnsAsync(Response.FromValue(chatCompletions, response!)); + var chatCompletion = ModelReaderWriter.Read(BinaryData.FromString(@$"{{ + ""choices"": [ + {{ + ""finish_reason"": ""stop"", + ""message"": {{ + ""role"": ""assistant"", + ""content"": ""test-choice"" + }} + }} + ] + }}")); + var response = new TestResponse(200, string.Empty); + clientMock.Setup((client) => + client + .GetChatClient(It.IsAny()) + .CompleteChatAsync(It.IsAny>(), It.IsAny(), It.IsAny()) + ).ReturnsAsync(ClientResult.FromValue(chatCompletion!, response)); + var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory()); openAIModel.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAIModel, clientMock.Object); @@ -314,122 +196,5 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat() Assert.Equal("test-choice", result.Message.Content); } - private static Choice CreateChoice(string text, int index, ContentFilterResultsForChoice? contentFilterResults, CompletionsLogProbabilityModel? logProbabilityModel, CompletionsFinishReason? finishReason, IDictionary? serializedAdditionalRawData) - { - Type[] paramTypes = new Type[] { typeof(string), typeof(int), typeof(ContentFilterResultsForChoice), typeof(CompletionsLogProbabilityModel), typeof(CompletionsFinishReason), typeof(IDictionary) }; - object[] paramValues = new object[] { text, index, contentFilterResults!, logProbabilityModel!, finishReason!, serializedAdditionalRawData! 
}; - return Construct(paramTypes, paramValues); - } - - private static CompletionsUsage CreateCompletionsUsage(int completionTokens, int promptTokens, int totalTokens) - { - Type[] paramTypes = new Type[] { typeof(int), typeof(int), typeof(int) }; - object[] paramValues = new object[] { completionTokens, promptTokens, totalTokens }; - return Construct(paramTypes, paramValues); - } - - private static Completions CreateCompletions(string id, DateTimeOffset created, IEnumerable choices, CompletionsUsage usage) - { - Type[] paramTypes = new Type[] { typeof(string), typeof(DateTimeOffset), typeof(IEnumerable), typeof(CompletionsUsage) }; - object[] paramValues = new object[] { id, created, choices, usage }; - return Construct(paramTypes, paramValues); - } - - private static ChatResponseMessage CreateChatResponseMessage(Azure.AI.OpenAI.ChatRole role, string content, IReadOnlyList? toolCalls, Azure.AI.OpenAI.FunctionCall? functionCall, AzureChatExtensionsMessageContext? azureExtensionsContext, IDictionary? serializedAdditionalRawData) - { - Type[] paramTypes = new Type[] { typeof(Azure.AI.OpenAI.ChatRole), typeof(string), typeof(IReadOnlyList), typeof(Azure.AI.OpenAI.FunctionCall), typeof(AzureChatExtensionsMessageContext), typeof(IDictionary) }; - object[] paramValues = new object[] { role, content, toolCalls!, functionCall!, azureExtensionsContext!, serializedAdditionalRawData! }; - return Construct(paramTypes, paramValues); - } - - private static ChatChoice CreateChatChoice(ChatResponseMessage message, ChatChoiceLogProbabilityInfo? logProbabilityInfo, int index, CompletionsFinishReason? finishReason, ChatFinishDetails? finishDetails, ChatResponseMessage? internalStreamingDeltaMessage, ContentFilterResultsForChoice? contentFilterResults, AzureChatEnhancements? enhancements, IDictionary? 
serializedAdditionalRawData) - { - Type[] paramTypes = new Type[] { typeof(ChatResponseMessage), typeof(ChatChoiceLogProbabilityInfo), typeof(int), typeof(CompletionsFinishReason), typeof(ChatFinishDetails), typeof(ChatResponseMessage), typeof(ContentFilterResultsForChoice), typeof(AzureChatEnhancements), typeof(IDictionary) }; - object[] paramValues = new object[] { message, logProbabilityInfo!, index, finishReason!, finishDetails!, internalStreamingDeltaMessage!, contentFilterResults!, enhancements!, serializedAdditionalRawData! }; - return Construct(paramTypes, paramValues); - } - - private static ChatCompletions CreateChatCompletions(string id, DateTimeOffset created, IEnumerable choices, CompletionsUsage usage) - { - Type[] paramTypes = new Type[] { typeof(string), typeof(DateTimeOffset), typeof(IEnumerable), typeof(CompletionsUsage) }; - object[] paramValues = new object[] { id, created, choices, usage }; - return Construct(paramTypes, paramValues); - } - - private static T Construct(Type[] paramTypes, object[] paramValues) - { - Type type = typeof(T); - ConstructorInfo info = type.GetConstructor(BindingFlags.Instance | BindingFlags.NonPublic, null, paramTypes, null)!; - - return (T)info.Invoke(paramValues); - } - } - - public class TestResponse : Response - { - private readonly Dictionary> _headers = new(StringComparer.OrdinalIgnoreCase); - - public TestResponse(int status, string reasonPhrase) - { - Status = status; - ReasonPhrase = reasonPhrase; - ClientRequestId = string.Empty; - } - - public override int Status { get; } - - public override string ReasonPhrase { get; } - - public override Stream? ContentStream { get; set; } - - public override string ClientRequestId { get; set; } - - private bool? _isError; - public override bool IsError => _isError ?? 
base.IsError; - public void SetIsError(bool value) - { - _isError = value; - } - - public bool IsDisposed { get; private set; } - - protected override bool TryGetHeader(string name, [NotNullWhen(true)] out string? value) - { - if (_headers.TryGetValue(name, out List? values)) - { - value = JoinHeaderValue(values); - return true; - } - - value = null; - return false; - } - - protected override bool TryGetHeaderValues(string name, [NotNullWhen(true)] out IEnumerable? values) - { - var result = _headers.TryGetValue(name, out List? valuesList); - values = valuesList; - return result; - } - - protected override bool ContainsHeader(string name) - { - return TryGetHeaderValues(name, out _); - } - - protected override IEnumerable EnumerateHeaders() - { - return _headers.Select(h => new HttpHeader(h.Key, JoinHeaderValue(h.Value))); - } - - private static string JoinHeaderValue(IEnumerable values) - { - return string.Join(",", values); - } - - public override void Dispose() - { - IsDisposed = true; - } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/SequentialDelayStrategyTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/SequentialDelayRetryPolicyTests.cs similarity index 50% rename from dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/SequentialDelayStrategyTests.cs rename to dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/SequentialDelayRetryPolicyTests.cs index eb7269390..e83cd2dc2 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/SequentialDelayStrategyTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/Models/SequentialDelayRetryPolicyTests.cs @@ -1,9 +1,9 @@ -using Azure; -using Microsoft.Teams.AI.AI.Models; +using Microsoft.Teams.AI.AI.Models; +using System.ClientModel.Primitives; namespace Microsoft.Teams.AI.Tests.AITests.Models { - public class SequentialDelayStrategyTests + 
public class SequentialDelayRetryPolicyTests { [Fact] public void Test_SequentialDelayStrategy() @@ -15,13 +15,13 @@ public void Test_SequentialDelayStrategy() TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(3000), }; - var strategy = new TestSequentialDelayStrategy(delays); + var strategy = new TestSequentialDelayRetryPolicy(delays); // Act - var result1 = strategy.GetNextDelayCoreMethod(null, 1); - var result2 = strategy.GetNextDelayCoreMethod(null, 2); - var result3 = strategy.GetNextDelayCoreMethod(null, 3); - var result4 = strategy.GetNextDelayCoreMethod(null, 4); + var result1 = strategy.GetNextDelayMethod(null, 1); + var result2 = strategy.GetNextDelayMethod(null, 2); + var result3 = strategy.GetNextDelayMethod(null, 3); + var result4 = strategy.GetNextDelayMethod(null, 4); // Assert Assert.Equal(TimeSpan.FromMilliseconds(1000), result1); @@ -31,15 +31,15 @@ public void Test_SequentialDelayStrategy() } } - internal sealed class TestSequentialDelayStrategy : SequentialDelayStrategy + internal sealed class TestSequentialDelayRetryPolicy : SequentialDelayRetryPolicy { - public TestSequentialDelayStrategy(List delays) : base(delays) + public TestSequentialDelayRetryPolicy(List delays) : base(delays) { } - public TimeSpan GetNextDelayCoreMethod(Response? response, int retryNumber) + public TimeSpan GetNextDelayMethod(PipelineMessage? 
message, int tryCount) { - return base.GetNextDelayCore(response, retryNumber); + return GetNextDelay(message!, tryCount); } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/OpenAIEmbeddingsTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/OpenAIEmbeddingsTests.cs index 0d8e59dbf..5dffb4b8c 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/OpenAIEmbeddingsTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/AITests/OpenAIEmbeddingsTests.cs @@ -1,9 +1,12 @@ -using Azure; -using Moq; +using Moq; using System.Reflection; using Microsoft.Teams.AI.AI.Embeddings; -using Azure.AI.OpenAI; +using OpenAI.Embeddings; +using OpenAI; using Microsoft.Teams.AI.Exceptions; +using System.ClientModel; +using Microsoft.Teams.AI.Tests.TestUtils; +using System.ClientModel.Primitives; #pragma warning disable CS8604 // Possible null reference argument. namespace Microsoft.Teams.AI.Tests.AITests @@ -21,15 +24,22 @@ public async void Test_OpenAI_CreateEmbeddings_ReturnEmbeddings() var openAiEmbeddings = new OpenAIEmbeddings(options); IList inputs = new List { "test" }; - var clientMock = new Mock(It.IsAny()); - IEnumerable data = new List() - { - AzureOpenAIModelFactory.EmbeddingItem() - }; - EmbeddingsUsage usage = AzureOpenAIModelFactory.EmbeddingsUsage(); - Embeddings embeddingsResult = AzureOpenAIModelFactory.Embeddings(data, usage); - Response? 
response = null; - clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny(), It.IsAny())).ReturnsAsync(Response.FromValue(embeddingsResult, response)); + var clientMock = new Mock(new ApiKeyCredential(apiKey), It.IsAny()); + var response = new TestResponse(200, string.Empty); + var embeddingCollection = ModelReaderWriter.Read(BinaryData.FromString(@"{ + ""data"": [ + { + ""object"": ""embedding"", + ""index"": 0, + ""embedding"": ""MC4wMDIzMDY0MjU1"" + } + ] + }")); + // MC4wMDIzMDY0MjU1= the base64 encoded float 0.0023064255 + clientMock.Setup(client => client + .GetEmbeddingClient(It.IsAny()) + .GenerateEmbeddingsAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddingCollection, response)); openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object); // Act @@ -53,15 +63,22 @@ public async void Test_AzureOpenAI_CreateEmbeddings_ReturnEmbeddings() var openAiEmbeddings = new OpenAIEmbeddings(options); IList inputs = new List { "test" }; - IEnumerable data = new List() - { - AzureOpenAIModelFactory.EmbeddingItem() - }; - EmbeddingsUsage usage = AzureOpenAIModelFactory.EmbeddingsUsage(); - Embeddings embeddingsResult = AzureOpenAIModelFactory.Embeddings(data, usage); - Response? 
response = null; - var clientMock = new Mock(It.IsAny()); - clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny(), It.IsAny())).ReturnsAsync(Response.FromValue(embeddingsResult, response)); + var clientMock = new Mock(new ApiKeyCredential(apiKey), It.IsAny()); + var response = new TestResponse(200, string.Empty); + var embeddingCollection = ModelReaderWriter.Read(BinaryData.FromString(@"{ + ""data"": [ + { + ""object"": ""embedding"", + ""index"": 0, + ""embedding"": ""MC4wMDIzMDY0MjU1"" + } + ] + }")); + // MC4wMDIzMDY0MjU1= the base64 encoded float 0.0023064255 + clientMock.Setup(client => client + .GetEmbeddingClient(It.IsAny()) + .GenerateEmbeddingsAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .ReturnsAsync(ClientResult.FromValue(embeddingCollection, response)); openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object); // Act @@ -76,7 +93,7 @@ public async void Test_AzureOpenAI_CreateEmbeddings_ReturnEmbeddings() [Theory] [InlineData(429, "too many requests", EmbeddingsResponseStatus.RateLimited)] [InlineData(502, "service error", EmbeddingsResponseStatus.Failure)] - public async void Test_CreateEmbeddings_ThrowRequestFailedException(int statusCode, string errorMsg, EmbeddingsResponseStatus responseStatus) + public async void Test_CreateEmbeddings_ThrowClientResultException(int statusCode, string errorMsg, EmbeddingsResponseStatus responseStatus) { // Arrange var apiKey = "randomApiKey"; @@ -86,9 +103,12 @@ public async void Test_CreateEmbeddings_ThrowRequestFailedException(int statusCo var openAiEmbeddings = new OpenAIEmbeddings(options); IList inputs = new List { "test" }; - var exception = new RequestFailedException(statusCode, errorMsg); - var clientMock = new Mock(It.IsAny()); - clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny(), It.IsAny())).ThrowsAsync(exception); + var clientMock = new Mock(new ApiKeyCredential(apiKey), It.IsAny()); + var 
response = new TestResponse(statusCode, errorMsg); + clientMock.Setup(client => client + .GetEmbeddingClient(It.IsAny()) + .GenerateEmbeddingsAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .ThrowsAsync(new ClientResultException(response)); openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object); // Act @@ -96,11 +116,12 @@ public async void Test_CreateEmbeddings_ThrowRequestFailedException(int statusCo // Assert Assert.NotNull(result); + Assert.Equal(responseStatus, result.Status); Assert.Null(result.Output); - Assert.Equal(result.Status, responseStatus); } [Fact] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "CA2201:Do not raise reserved exception types", Justification = "")] public async void Test_CreateEmbeddings_ThrowException() { // Arrange @@ -111,16 +132,19 @@ public async void Test_CreateEmbeddings_ThrowException() var openAiEmbeddings = new OpenAIEmbeddings(options); IList inputs = new List { "test" }; - var exception = new InvalidOperationException("other exception"); - var clientMock = new Mock(It.IsAny()); - clientMock.Setup(client => client.GetEmbeddingsAsync(It.IsAny(), It.IsAny())).ThrowsAsync(exception); + var clientMock = new Mock(new ApiKeyCredential(apiKey), It.IsAny()); + clientMock.Setup(client => client + .GetEmbeddingClient(It.IsAny()) + .GenerateEmbeddingsAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .ThrowsAsync(new Exception("test-exception")); openAiEmbeddings.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAiEmbeddings, clientMock.Object); // Act - var result = await Assert.ThrowsAsync(async () => await openAiEmbeddings.CreateEmbeddingsAsync(inputs)); + var exception = await Assert.ThrowsAsync(async () => await openAiEmbeddings.CreateEmbeddingsAsync(inputs)); // Assert - Assert.Equal("Error while executing openAI Embeddings execution: other exception", result.Message); + 
Assert.NotNull(exception); + Assert.Equal("Error while executing openAI Embeddings execution: test-exception", exception.Message); } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIEmbeddingsTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIEmbeddingsTests.cs index ce4072201..00c5ab9b0 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIEmbeddingsTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIEmbeddingsTests.cs @@ -35,7 +35,7 @@ public OpenAIEmbeddingsTests(ITestOutputHelper output) .Build(); } - [Theory(Skip = "This test should only be run manually.")] + [Fact(Skip = "This test should only be run manually.")] public async Task Test_CreateEmbeddingsAsync_OpenAI() { // Arrange @@ -60,7 +60,7 @@ public async Task Test_CreateEmbeddingsAsync_OpenAI() Assert.Equal(dimension, result.Output[1].Length); } - [Theory(Skip = "This test should only be run manually.")] + [Fact(Skip = "This test should only be run manually.")] public async Task Test_CreateEmbeddingsAsync_AzureOpenAI() { // Arrange diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIModelTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIModelTests.cs index 4fe603757..4c6a3eb86 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIModelTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/IntegrationTests/OpenAIModelTests.cs @@ -43,11 +43,14 @@ public OpenAIModelTests(ITestOutputHelper output) [Theory(Skip = "Should only run manually for now.")] [InlineData("What is the capital of Thailand?", "Bangkok")] - public async Task OpenAIModel_CompletePrompt(string input, string expectedAnswer) + public async Task OpenAIModel_CompleteChatPrompt(string input, string expectedAnswer) { // Arrange var 
config = _configuration.GetSection("OpenAI").Get(); - var modelOptions = new AI.Models.OpenAIModelOptions(config.ApiKey, config.ChatModelId!); + var modelOptions = new AI.Models.OpenAIModelOptions(config.ApiKey, config.ChatModelId!) + { + CompletionType = CompletionConfiguration.CompletionType.Chat + }; var model = new AI.Models.OpenAIModel(modelOptions); var botAdapterMock = new Mock(); diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Microsoft.Teams.AI.Tests.csproj b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Microsoft.Teams.AI.Tests.csproj index f3315e0e7..4dd21c037 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Microsoft.Teams.AI.Tests.csproj +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Microsoft.Teams.AI.Tests.csproj @@ -11,11 +11,12 @@ - + + diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/TestUtils/TestResponse.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/TestUtils/TestResponse.cs new file mode 100644 index 000000000..5e5140733 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/TestUtils/TestResponse.cs @@ -0,0 +1,53 @@ +using System.ClientModel.Primitives; + +namespace Microsoft.Teams.AI.Tests.TestUtils +{ + public class TestResponse : PipelineResponse + { +#pragma warning disable CS8618 + public TestResponse(int status, string reasonPhrase) +#pragma warning restore CS8618 + { + Status = status; + ReasonPhrase = reasonPhrase; + Content = BinaryData.FromString(""); +#pragma warning disable CS8625 + HeadersCore = null; +#pragma warning restore CS8625 + } + + public override int Status { get; } + + public override string ReasonPhrase { get; } + + public override Stream? ContentStream { get; set; } + + public override BinaryData Content { get; } + + protected override PipelineResponseHeaders HeadersCore { get; } + + private bool? _isError; + public override bool IsError => _isError ?? 
base.IsError; + public void SetIsError(bool value) + { + _isError = value; + } + + public bool IsDisposed { get; private set; } + + public override BinaryData BufferContent(CancellationToken cancellationToken) + { + return BinaryData.FromString(""); + } + + public override ValueTask BufferContentAsync(CancellationToken cancellationToken) + { + return ValueTask.FromResult(BinaryData.FromString("")); + } + + public override void Dispose() + { + IsDisposed = true; + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/bin/Debug/netstandard2.0/CoverletSourceRootsMapping b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/bin/Debug/netstandard2.0/CoverletSourceRootsMapping deleted file mode 100644 index 5291141dabae762b370bad0d741cdfbbbaed2cb6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 390 zcmb`Cp$@`85JaCQ;ZwK|AVClm0tfC)Q%WfWO1QR)ZwKBXPzXfL?B3kW?C$HiVoqd7 zqT((_&J+wZTzzG%Im#wldKIO*Xsx)(WPN8}aIv3R8=1X7oz6;(>Iz*swU)__jjGa& o>~A@!3T(eAOV*=lf^(7^ { TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(5000) } + RetryPolicy = options.RetryPolicy ?? new List { TimeSpan.FromMilliseconds(2000), TimeSpan.FromMilliseconds(5000) }, }; _logger = loggerFactory == null ? 
NullLogger.Instance : loggerFactory.CreateLogger(); + OpenAIEmbeddingsOptions embeddingsOptions = (OpenAIEmbeddingsOptions)_options; OpenAIClientOptions openAIClientOptions = new() { - RetryPolicy = new RetryPolicy(_options.RetryPolicy!.Count, new SequentialDelayStrategy(_options.RetryPolicy)) + RetryPolicy = new SequentialDelayRetryPolicy(embeddingsOptions.RetryPolicy!, embeddingsOptions.RetryPolicy!.Count) }; - openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), HttpPipelinePosition.PerCall); + + openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), PipelinePosition.PerCall); if (httpClient != null) { - openAIClientOptions.Transport = new HttpClientTransport(httpClient); + openAIClientOptions.Transport = new HttpClientPipelineTransport(httpClient); } - OpenAIEmbeddingsOptions openAIModelOptions = (OpenAIEmbeddingsOptions)_options; - if (!string.IsNullOrEmpty(openAIModelOptions.Organization)) + + if (!string.IsNullOrEmpty(embeddingsOptions.Organization)) { - openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("OpenAI-Organization", openAIModelOptions.Organization!), HttpPipelinePosition.PerCall); + openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("OpenAI-Organization", options.Organization!), PipelinePosition.PerCall); } - _openAIClient = new OpenAIClient(openAIModelOptions.ApiKey, openAIClientOptions); + _openAIClient = new OpenAIClient(new ApiKeyCredential(embeddingsOptions.ApiKey), openAIClientOptions); _deploymentName = options.Model; } @@ -76,7 +79,7 @@ public OpenAIEmbeddings(AzureOpenAIEmbeddingsOptions options, ILoggerFactory? lo Verify.ParamNotNull(options.AzureDeployment, "AzureOpenAIEmbeddingsOptions.AzureDeployment"); Verify.ParamNotNull(options.AzureEndpoint, "AzureOpenAIEmbeddingsOptions.AzureEndpoint"); - string apiVersion = options.AzureApiVersion ?? "2023-05-15"; + string apiVersion = options.AzureApiVersion ?? "2024-06-01"; ServiceVersion? 
serviceVersion = ConvertStringToServiceVersion(apiVersion); if (serviceVersion == null) { @@ -91,18 +94,20 @@ public OpenAIEmbeddings(AzureOpenAIEmbeddingsOptions options, ILoggerFactory? lo }; _logger = loggerFactory == null ? NullLogger.Instance : loggerFactory.CreateLogger(); - OpenAIClientOptions openAIClientOptions = new(serviceVersion.Value) + + AzureOpenAIEmbeddingsOptions azureEmbeddingsOptions = (AzureOpenAIEmbeddingsOptions)_options; + AzureOpenAIClientOptions azureOpenAIClientOptions = new(serviceVersion.Value) { - RetryPolicy = new RetryPolicy(_options.RetryPolicy!.Count, new SequentialDelayStrategy(_options.RetryPolicy)) + RetryPolicy = new SequentialDelayRetryPolicy(_options.RetryPolicy, _options.RetryPolicy.Count) }; - openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), HttpPipelinePosition.PerCall); + + azureOpenAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), PipelinePosition.PerCall); if (httpClient != null) { - openAIClientOptions.Transport = new HttpClientTransport(httpClient); + azureOpenAIClientOptions.Transport = new HttpClientPipelineTransport(httpClient); } - AzureOpenAIEmbeddingsOptions azureOpenAIModelOptions = (AzureOpenAIEmbeddingsOptions)_options; - _openAIClient = new OpenAIClient(new Uri(azureOpenAIModelOptions.AzureEndpoint), new AzureKeyCredential(azureOpenAIModelOptions.AzureApiKey), openAIClientOptions); + _openAIClient = new AzureOpenAIClient(new Uri(azureEmbeddingsOptions.AzureEndpoint), new ApiKeyCredential(azureEmbeddingsOptions.AzureApiKey), azureOpenAIClientOptions); _deploymentName = options.AzureDeployment; } @@ -114,13 +119,13 @@ public async Task CreateEmbeddingsAsync(IList inputs _logger?.LogInformation($"\nEmbeddings REQUEST: inputs={inputs}"); } - EmbeddingsOptions embeddingsOptions = new(_deploymentName, inputs); + EmbeddingClient embeddingsClient = _openAIClient.GetEmbeddingClient(_deploymentName); try { DateTime startTime = DateTime.Now; - Response 
response = await _openAIClient.GetEmbeddingsAsync(embeddingsOptions, cancellationToken); - List> embeddingItems = response.Value.Data.OrderBy(item => item.Index).Select(item => item.Embedding).ToList(); + ClientResult response = await embeddingsClient.GenerateEmbeddingsAsync(inputs); + List> embeddingItems = response.Value.OrderBy(item => item.Index).Select(item => item.Vector).ToList(); if (_options.LogRequests!.Value) { @@ -134,7 +139,7 @@ public async Task CreateEmbeddingsAsync(IList inputs Output = embeddingItems, }; } - catch (RequestFailedException ex) when (ex.Status == 429) + catch (ClientResultException ex) when (ex.Status == 429) { return new EmbeddingsResponse { @@ -142,7 +147,7 @@ public async Task CreateEmbeddingsAsync(IList inputs Message = $"The embeddings API returned a rate limit error", }; } - catch (RequestFailedException ex) + catch (ClientResultException ex) { return new EmbeddingsResponse { @@ -158,17 +163,13 @@ public async Task CreateEmbeddingsAsync(IList inputs private ServiceVersion? 
ConvertStringToServiceVersion(string apiVersion) { - switch (apiVersion) + return apiVersion switch { - case "2022-12-01": return ServiceVersion.V2022_12_01; - case "2023-05-15": return ServiceVersion.V2023_05_15; - case "2023-06-01-preview": return ServiceVersion.V2023_06_01_Preview; - case "2023-07-01-preview": return ServiceVersion.V2023_07_01_Preview; - case "2024-02-15-preview": return ServiceVersion.V2024_02_15_Preview; - case "2024-03-01-preview": return ServiceVersion.V2024_03_01_Preview; - default: - return null; - } + "2024-04-01-preview" => ServiceVersion.V2024_04_01_Preview, + "2024-05-01-preview" => ServiceVersion.V2024_05_01_Preview, + "2024-06-01" => ServiceVersion.V2024_06_01, + _ => null, + }; } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AddHeaderRequestPolicy.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AddHeaderRequestPolicy.cs index 58d0b355f..07a287bba 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AddHeaderRequestPolicy.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AddHeaderRequestPolicy.cs @@ -1,25 +1,33 @@ -using Azure.Core.Pipeline; -using Azure.Core; +using System.ClientModel.Primitives; namespace Microsoft.Teams.AI.AI.Models { /// - /// Helper class to inject headers into Azure SDK HTTP pipeline. + /// Helper class to inject headers into HTTP pipeline. 
/// - internal class AddHeaderRequestPolicy : HttpPipelineSynchronousPolicy + internal class AddHeaderRequestPolicy : PipelinePolicy { private readonly string _headerName; private readonly string _headerValue; - public AddHeaderRequestPolicy(string headerName, string headerValue) + public AddHeaderRequestPolicy(string headerName, string headerValue) : base() { this._headerName = headerName; this._headerValue = headerValue; } - public override void OnSendingRequest(HttpMessage message) + public override ValueTask ProcessAsync(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) { message.Request.Headers.Add(this._headerName, this._headerValue); + + return ProcessNextAsync(message, pipeline, currentIndex); + } + + public override void Process(PipelineMessage message, IReadOnlyList pipeline, int currentIndex) + { + message.Request.Headers.Add(this._headerName, this._headerValue); + + ProcessNext(message, pipeline, currentIndex); } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AzureSdkChatMessageExtensions.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AzureSdkChatMessageExtensions.cs deleted file mode 100644 index e627870e2..000000000 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/AzureSdkChatMessageExtensions.cs +++ /dev/null @@ -1,67 +0,0 @@ -using Azure.AI.OpenAI; -using Microsoft.Teams.AI.Exceptions; - -namespace Microsoft.Teams.AI.AI.Models -{ - /// - /// Provides extension methods for the class. - /// - internal static class AzureSdkChatMessageExtensions - { - /// - /// Converts an to a . - /// - /// The original . - /// A . 
- public static ChatMessage ToChatMessage(this ChatResponseMessage chatMessage) - { - ChatMessage message = new(new ChatRole(chatMessage.Role.ToString())) - { - Content = chatMessage.Content, - }; - - if (chatMessage.FunctionCall != null) - { - message.Name = chatMessage.FunctionCall.Name; - message.FunctionCall = new FunctionCall(chatMessage.FunctionCall.Name, chatMessage.FunctionCall.Arguments); - } - - if (chatMessage.ToolCalls != null && chatMessage.ToolCalls.Count > 0) - { - message.ToolCalls = new List(); - foreach (Azure.AI.OpenAI.ChatCompletionsToolCall toolCall in chatMessage.ToolCalls) - { - message.ToolCalls.Add(toolCall.ToChatCompletionsToolCall()); - } - - } - - message.Context = new MessageContext(); - if (chatMessage.AzureExtensionsContext?.Intent != null) - { - message.Context.Intent = chatMessage.AzureExtensionsContext.Intent; - } - - IReadOnlyList? citations = chatMessage.AzureExtensionsContext?.Citations; - if (citations != null) - { - foreach (AzureChatExtensionDataSourceResponseCitation citation in citations) - { - message.Context.Citations.Add(new Citation(citation.Content, citation.Title, citation.Url)); - }; - } - - return message; - } - - public static ChatCompletionsToolCall ToChatCompletionsToolCall(this Azure.AI.OpenAI.ChatCompletionsToolCall toolCall) - { - if (toolCall is Azure.AI.OpenAI.ChatCompletionsFunctionToolCall azureFunctionToolCall) - { - return new ChatCompletionsFunctionToolCall(azureFunctionToolCall.Id, azureFunctionToolCall.Name, azureFunctionToolCall.Arguments); - } - - throw new TeamsAIException($"Invalid ChatCompletionsToolCall type: {toolCall.GetType().Name}"); - } - } -} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatCompletionToolCall.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatCompletionToolCall.cs index 7a078aade..2bb8ac0eb 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatCompletionToolCall.cs +++ 
b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatCompletionToolCall.cs @@ -1,4 +1,6 @@ -using Microsoft.Teams.AI.Utilities; +using Microsoft.Teams.AI.Exceptions; +using Microsoft.Teams.AI.Utilities; +using OpenAI.Chat; namespace Microsoft.Teams.AI.AI.Models { @@ -30,6 +32,38 @@ internal ChatCompletionsToolCall(string type, string id) Type = type; Id = id; } + + /// + /// Maps to OpenAI.Chat.ChatToolCall + /// + /// The mapped OpenAI.Chat.ChatToolCall object. + /// If the tool call type is not valid. + internal ChatToolCall ToChatToolCall() + { + if (this.Type == ToolType.Function) + { + ChatCompletionsFunctionToolCall functionToolCall = (ChatCompletionsFunctionToolCall)this; + return ChatToolCall.CreateFunctionToolCall(functionToolCall.Id, functionToolCall.Name, functionToolCall.Arguments); + } + + throw new TeamsAIException($"Invalid tool type: {this.Type}"); + } + + /// + /// Maps OpenAI.Chat.ChatToolCall to ChatCompletionsToolCall + /// + /// The tool call. + /// The mapped ChatCompletionsToolCall object + /// If the tool call type is not valid. 
+ internal static ChatCompletionsToolCall FromChatToolCall(ChatToolCall toolCall) + { + if (toolCall.Kind == ChatToolCallKind.Function) + { + return new ChatCompletionsFunctionToolCall(toolCall.Id, toolCall.FunctionName, toolCall.FunctionArguments); + } + + throw new TeamsAIException($"Invalid ChatCompletionsToolCall type: {toolCall.GetType().Name}"); + } } /// diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessage.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessage.cs index 07f0de4e1..5de857ea4 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessage.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessage.cs @@ -1,4 +1,11 @@ -namespace Microsoft.Teams.AI.AI.Models +using Azure.AI.OpenAI; +using Azure.AI.OpenAI.Chat; +using Microsoft.Teams.AI.Exceptions; +using Microsoft.Teams.AI.Utilities; +using OpenAI.Chat; +using OAI = OpenAI; + +namespace Microsoft.Teams.AI.AI.Models { /// /// Represents a message that will be passed to the Chat Completions API @@ -53,12 +60,172 @@ public TContent GetContent() return (TContent)Content!; } - /// Initializes a new instance of ChatMessage. + /// + /// Initializes a new instance of ChatMessage. + /// /// The role associated with this message payload. public ChatMessage(ChatRole role) { this.Role = role; } + + /// + /// Initializes a new instance of ChatMessage using OpenAI.Chat.ChatCompletion. 
+ /// + /// + internal ChatMessage(ChatCompletion chatCompletion) + { + this.Role = ChatRole.Assistant; + this.Content = chatCompletion.Content[0].Text; + + if (chatCompletion.FunctionCall != null && chatCompletion.FunctionCall.FunctionName != string.Empty) + { + this.Name = chatCompletion.FunctionCall.FunctionName; + this.FunctionCall = new FunctionCall(chatCompletion.FunctionCall.FunctionName, chatCompletion.FunctionCall.FunctionArguments); + } + + if (chatCompletion.ToolCalls != null && chatCompletion.ToolCalls.Count > 0) + { + this.ToolCalls = new List(); + foreach (ChatToolCall toolCall in chatCompletion.ToolCalls) + { + this.ToolCalls.Add(ChatCompletionsToolCall.FromChatToolCall(toolCall)); + } + } + +#pragma warning disable AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + AzureChatMessageContext? azureContext = chatCompletion.GetAzureMessageContext(); +#pragma warning restore AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (azureContext != null) + { + MessageContext? context = new(azureContext); + if (context != null) + { + this.Context = context; + } + } + } + + internal OAI.Chat.ChatMessage ToOpenAIChatMessage() + { + Verify.NotNull(this.Content); + Verify.NotNull(this.Role); + + ChatRole role = this.Role; + OAI.Chat.ChatMessage? message = null; + + string? content = null; + List contentItems = new(); + + // Content is a text + if (this.Content is string textContent) + { + content = textContent; + } + else if (this.Content is IEnumerable contentParts) + { + // Content has multiple, possibly multi-modal parts.
+ foreach (MessageContentParts contentPart in contentParts) + { + if (contentPart is TextContentPart textPart) + { + contentItems.Add(ChatMessageContentPart.CreateTextMessageContentPart(textPart.Text)); + } + else if (contentPart is ImageContentPart imagePart) + { + contentItems.Add(ChatMessageContentPart.CreateImageMessageContentPart(new Uri(imagePart.ImageUrl))); + } + } + } + + // Different roles map to different classes + if (role == ChatRole.User) + { + UserChatMessage userMessage; + if (content != null) + { + userMessage = new(content); + } + else + { + userMessage = new(contentItems); + } + + if (this.Name != null) + { + // TODO: Currently no way to set `ParticipantName` come and change it eventually. + //userMessage.ParticipantName = this.Name; + } + + message = userMessage; + } + + if (role == ChatRole.Assistant) + { + AssistantChatMessage assistantMessage; + + if (this.FunctionCall != null) + { + ChatFunctionCall functionCall = new(this.FunctionCall.Name ?? "", this.FunctionCall.Arguments ?? ""); + assistantMessage = new AssistantChatMessage(functionCall, this.GetContent()); + } + else if (this.ToolCalls != null) + { + List toolCalls = new(); + foreach (ChatCompletionsToolCall toolCall in this.ToolCalls) + { + toolCalls.Add(toolCall.ToChatToolCall()); + } + assistantMessage = new AssistantChatMessage(toolCalls, this.GetContent()); + } + else + { + assistantMessage = new AssistantChatMessage(this.GetContent()); + } + + if (this.Name != null) + { + // TODO: Currently no way to set `ParticipantName` come and change it eventually. + // assistantMessage.ParticipantName = this.Name; + } + + message = assistantMessage; + } + + if (role == ChatRole.System) + { + SystemChatMessage systemMessage = new(this.GetContent()); + + if (this.Name != null) + { + // TODO: Currently no way to set `ParticipantName` come and change it eventually. 
+ // systemMessage.ParticipantName = chatMessage.Name; + } + + message = systemMessage; + } + + if (role == ChatRole.Function) + { + // TODO: Clean up +#pragma warning disable CS0618 // Type or member is obsolete + message = new FunctionChatMessage(this.Name ?? "", this.GetContent()); +#pragma warning restore CS0618 // Type or member is obsolete + } + + if (role == ChatRole.Tool) + { + + message = new ToolChatMessage(this.ToolCallId ?? "", this.GetContent()); + } + + if (message == null) + { + throw new TeamsAIException($"Invalid chat message role: {role}"); + } + + return message; + } } /// diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessageExtensions.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessageExtensions.cs deleted file mode 100644 index b0f8d94f5..000000000 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/ChatMessageExtensions.cs +++ /dev/null @@ -1,136 +0,0 @@ -using Azure.AI.OpenAI; -using Microsoft.Teams.AI.Exceptions; -using Microsoft.Teams.AI.Utilities; - -namespace Microsoft.Teams.AI.AI.Models -{ - /// - /// Provides extension methods for the class. - /// - internal static class ChatMessageExtensions - { - /// - /// Converts a to an . - /// - /// The original . - /// An . - public static ChatRequestMessage ToChatRequestMessage(this ChatMessage chatMessage) - { - Verify.NotNull(chatMessage.Content); - Verify.NotNull(chatMessage.Role); - - ChatRole role = chatMessage.Role; - ChatRequestMessage? message = null; - - string? content = null; - List contentItems = new(); - - // Content is a text - if (chatMessage.Content is string textContent) - { - content = textContent; - } - else if (chatMessage.Content is IEnumerable contentParts) - { - // Content is has multiple possibly multi-modal parts. 
- foreach (MessageContentParts contentPart in contentParts) - { - if (contentPart is TextContentPart textPart) - { - contentItems.Add(new ChatMessageTextContentItem(textPart.Text)); - } - else if (contentPart is ImageContentPart imagePart) - { - contentItems.Add(new ChatMessageImageContentItem(new Uri(imagePart.ImageUrl))); - } - } - } - - // Different roles map to different classes - if (role == ChatRole.User) - { - ChatRequestUserMessage userMessage; - if (content != null) - { - userMessage = new(content); - } - else - { - userMessage = new(contentItems); - } - - if (chatMessage.Name != null) - { - userMessage.Name = chatMessage.Name; - } - - message = userMessage; - } - - if (role == ChatRole.Assistant) - { - ChatRequestAssistantMessage assistantMessage = new(chatMessage.GetContent()); - - if (chatMessage.FunctionCall != null) - { - assistantMessage.FunctionCall = new(chatMessage.FunctionCall.Name ?? "", chatMessage.FunctionCall.Arguments ?? ""); - } - - if (chatMessage.ToolCalls != null) - { - foreach (ChatCompletionsToolCall toolCall in chatMessage.ToolCalls) - { - assistantMessage.ToolCalls.Add(toolCall.ToAzureSdkChatCompletionsToolCall()); - } - } - - if (chatMessage.Name != null) - { - assistantMessage.Name = chatMessage.Name; - } - - message = assistantMessage; - } - - if (role == ChatRole.System) - { - ChatRequestSystemMessage systemMessage = new(chatMessage.GetContent()); - - if (chatMessage.Name != null) - { - systemMessage.Name = chatMessage.Name; - } - - message = systemMessage; - } - - if (role == ChatRole.Function) - { - message = new ChatRequestFunctionMessage(chatMessage.Name ?? "", chatMessage.GetContent()); - } - - if (role == ChatRole.Tool) - { - message = new ChatRequestToolMessage(chatMessage.GetContent(), chatMessage.ToolCallId ?? 
""); - } - - if (message == null) - { - throw new TeamsAIException($"Invalid chat message role: {role}"); - } - - return message; - } - - public static Azure.AI.OpenAI.ChatCompletionsToolCall ToAzureSdkChatCompletionsToolCall(this ChatCompletionsToolCall toolCall) - { - if (toolCall.Type == ToolType.Function) - { - ChatCompletionsFunctionToolCall functionToolCall = (ChatCompletionsFunctionToolCall)toolCall; - return new Azure.AI.OpenAI.ChatCompletionsFunctionToolCall(functionToolCall.Id, functionToolCall.Name, functionToolCall.Arguments); - } - - throw new TeamsAIException($"Invalid tool type: {toolCall.Type}"); - } - } -} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs index 33c83c771..10f808f4b 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/MessageContext.cs @@ -1,4 +1,5 @@ -using Microsoft.Teams.AI.Utilities; +using Azure.AI.OpenAI.Chat; +using Microsoft.Teams.AI.Utilities; namespace Microsoft.Teams.AI.AI.Models { @@ -16,6 +17,28 @@ public class MessageContext /// The intent of the message. /// public string Intent { get; set; } = string.Empty; + + /// + /// Creates a MessageContext + /// + public MessageContext() { } + + /// + /// Creates a MessageContext using OpenAI.Chat.AzureChatMessageContext. 
+ /// + /// + internal MessageContext(AzureChatMessageContext azureContext) + { + if (azureContext.Citations != null) + { + foreach (AzureChatCitation citation in azureContext.Citations) + { + this.Citations.Add(new Citation(citation.Content, citation.Title, citation.Url)); + } + } + + this.Intent = azureContext.Intent; + } } /// diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs index 5d26f7f68..b89afe27b 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/OpenAIModel.cs @@ -1,8 +1,4 @@ -using Azure; -using Azure.AI.OpenAI; -using Azure.Core; -using Azure.Core.Pipeline; -using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Teams.AI.AI.Prompts; @@ -14,8 +10,14 @@ using System.ClientModel.Primitives; using System.Net; using System.Text.Json; -using static Azure.AI.OpenAI.OpenAIClientOptions; +using OpenAI; +using OAIChat = OpenAI.Chat; +using Azure.AI.OpenAI; using static Microsoft.Teams.AI.AI.Prompts.CompletionConfiguration; +using System.ClientModel; +using ServiceVersion = Azure.AI.OpenAI.AzureOpenAIClientOptions.ServiceVersion; +using Azure.AI.OpenAI.Chat; +using OpenAI.Chat; namespace Microsoft.Teams.AI.AI.Models { @@ -29,7 +31,8 @@ public class OpenAIModel : IPromptCompletionModel private readonly OpenAIClient _openAIClient; private readonly string _deploymentName; - private readonly static JsonSerializerOptions _serializerOptions = new() + private readonly bool _useAzure; + private static readonly JsonSerializerOptions _serializerOptions = new() { WriteIndented = true, Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping @@ -49,6 +52,7 @@ public OpenAIModel(OpenAIModelOptions options, ILoggerFactory? 
loggerFactory = n Verify.ParamNotNull(options.ApiKey, "OpenAIModelOptions.ApiKey"); Verify.ParamNotNull(options.DefaultModel, "OpenAIModelOptions.DefaultModel"); + _useAzure = false; _options = new OpenAIModelOptions(options.ApiKey, options.DefaultModel) { Organization = options.Organization, @@ -61,19 +65,20 @@ public OpenAIModel(OpenAIModelOptions options, ILoggerFactory? loggerFactory = n OpenAIClientOptions openAIClientOptions = new() { - RetryPolicy = new RetryPolicy(_options.RetryPolicy!.Count, new SequentialDelayStrategy(_options.RetryPolicy)) + RetryPolicy = new SequentialDelayRetryPolicy(_options.RetryPolicy, _options.RetryPolicy.Count) }; - openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), HttpPipelinePosition.PerCall); + + openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), PipelinePosition.PerCall); if (httpClient != null) { - openAIClientOptions.Transport = new HttpClientTransport(httpClient); + openAIClientOptions.Transport = new HttpClientPipelineTransport(httpClient); } OpenAIModelOptions openAIModelOptions = (OpenAIModelOptions)_options; if (!string.IsNullOrEmpty(openAIModelOptions.Organization)) { - openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("OpenAI-Organization", openAIModelOptions.Organization!), HttpPipelinePosition.PerCall); + openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("OpenAI-Organization", openAIModelOptions.Organization!), PipelinePosition.PerCall); } - _openAIClient = new OpenAIClient(openAIModelOptions.ApiKey, openAIClientOptions); + _openAIClient = new OpenAIClient(new ApiKeyCredential(openAIModelOptions.ApiKey), openAIClientOptions); _deploymentName = options.DefaultModel; } @@ -90,13 +95,14 @@ public OpenAIModel(AzureOpenAIModelOptions options, ILoggerFactory? 
loggerFactor Verify.ParamNotNull(options.AzureApiKey, "AzureOpenAIModelOptions.AzureApiKey"); Verify.ParamNotNull(options.AzureDefaultDeployment, "AzureOpenAIModelOptions.AzureDefaultDeployment"); Verify.ParamNotNull(options.AzureEndpoint, "AzureOpenAIModelOptions.AzureEndpoint"); - string apiVersion = options.AzureApiVersion ?? "2024-02-15-preview"; + string apiVersion = options.AzureApiVersion ?? "2024-06-01"; ServiceVersion? serviceVersion = ConvertStringToServiceVersion(apiVersion); if (serviceVersion == null) { throw new ArgumentException($"Model created with an unsupported API version of `{apiVersion}`."); } + _useAzure = true; _options = new AzureOpenAIModelOptions(options.AzureApiKey, options.AzureDefaultDeployment, options.AzureEndpoint) { AzureApiVersion = apiVersion, @@ -107,17 +113,18 @@ public OpenAIModel(AzureOpenAIModelOptions options, ILoggerFactory? loggerFactor }; _logger = loggerFactory == null ? NullLogger.Instance : loggerFactory.CreateLogger(); - OpenAIClientOptions openAIClientOptions = new(serviceVersion.Value) + AzureOpenAIClientOptions azureOpenAIClientOptions = new(serviceVersion.Value) { - RetryPolicy = new RetryPolicy(_options.RetryPolicy!.Count, new SequentialDelayStrategy(_options.RetryPolicy)) + RetryPolicy = new SequentialDelayRetryPolicy(_options.RetryPolicy, _options.RetryPolicy.Count) }; - openAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), HttpPipelinePosition.PerCall); + + azureOpenAIClientOptions.AddPolicy(new AddHeaderRequestPolicy("User-Agent", _userAgent), PipelinePosition.PerCall); if (httpClient != null) { - openAIClientOptions.Transport = new HttpClientTransport(httpClient); + azureOpenAIClientOptions.Transport = new HttpClientPipelineTransport(httpClient); } AzureOpenAIModelOptions azureOpenAIModelOptions = (AzureOpenAIModelOptions)_options; - _openAIClient = new OpenAIClient(new Uri(azureOpenAIModelOptions.AzureEndpoint), new AzureKeyCredential(azureOpenAIModelOptions.AzureApiKey), 
openAIClientOptions); + _openAIClient = new AzureOpenAIClient(new Uri(azureOpenAIModelOptions.AzureEndpoint), new ApiKeyCredential(azureOpenAIModelOptions.AzureApiKey), azureOpenAIClientOptions); _deploymentName = options.AzureDefaultDeployment; } @@ -129,83 +136,7 @@ public async Task CompletePromptAsync(ITurnContext turnContext, int maxInputTokens = promptTemplate.Configuration.Completion.MaxInputTokens; - if (_options.CompletionType == CompletionType.Text) - { - // Render prompt - RenderedPromptSection prompt = await promptTemplate.Prompt.RenderAsTextAsync(turnContext, memory, promptFunctions, tokenizer, maxInputTokens, cancellationToken); - if (prompt.TooLong) - { - return new PromptResponse - { - Status = PromptResponseStatus.TooLong, - Error = new($"The generated text completion prompt had a length of {prompt.Length} tokens which exceeded the MaxInputTokens of {maxInputTokens}.") - }; - } - if (_options.LogRequests!.Value) - { - // TODO: Colorize - _logger.LogTrace("PROMPT:"); - _logger.LogTrace(prompt.Output); - } - - CompletionsOptions completionsOptions = new(_deploymentName, new List { prompt.Output }) - { - MaxTokens = maxInputTokens, - Temperature = (float)promptTemplate.Configuration.Completion.Temperature, - NucleusSamplingFactor = (float)promptTemplate.Configuration.Completion.TopP, - PresencePenalty = (float)promptTemplate.Configuration.Completion.PresencePenalty, - FrequencyPenalty = (float)promptTemplate.Configuration.Completion.FrequencyPenalty, - }; - - Response? rawResponse; - Response? 
completionsResponse = null; - PromptResponse promptResponse = new(); - try - { - completionsResponse = await _openAIClient.GetCompletionsAsync(completionsOptions, cancellationToken); - rawResponse = completionsResponse.GetRawResponse(); - promptResponse.Status = PromptResponseStatus.Success; - promptResponse.Message = new ChatMessage(ChatRole.Assistant) - { - Content = completionsResponse.Value.Choices[0].Text - }; - } - catch (RequestFailedException e) - { - rawResponse = e.GetRawResponse(); - HttpOperationException httpOperationException = e.ToHttpOperationException(); - if (httpOperationException.StatusCode == (HttpStatusCode)429) - { - promptResponse.Status = PromptResponseStatus.RateLimited; - promptResponse.Error = new("The text completion API returned a rate limit error."); - } - else - { - promptResponse.Status = PromptResponseStatus.Error; - promptResponse.Error = new($"The text completion API returned an error status of {httpOperationException.StatusCode}: {httpOperationException.Message}"); - } - } - - if (_options.LogRequests!.Value) - { - // TODO: Colorize - _logger.LogTrace("RESPONSE:"); - _logger.LogTrace($"status {rawResponse!.Status}"); - _logger.LogTrace($"duration {(DateTime.UtcNow - startTime).TotalMilliseconds} ms"); - if (promptResponse.Status == PromptResponseStatus.Success) - { - _logger.LogTrace(JsonSerializer.Serialize(completionsResponse!.Value, _serializerOptions)); - } - if (promptResponse.Status == PromptResponseStatus.RateLimited) - { - _logger.LogTrace("HEADERS:"); - _logger.LogTrace(JsonSerializer.Serialize(rawResponse.Headers, _serializerOptions)); - } - } - - return promptResponse; - } - else + if (_options.CompletionType == CompletionType.Chat) { // Render prompt RenderedPromptSection> prompt = await promptTemplate.Prompt.RenderAsMessagesAsync(turnContext, memory, promptFunctions, tokenizer, maxInputTokens, cancellationToken); @@ -238,34 +169,53 @@ public async Task CompletePromptAsync(ITurnContext turnContext, } // Call chat 
completion API - IEnumerable chatMessages = prompt.Output.Select(chatMessage => chatMessage.ToChatRequestMessage()); - ChatCompletionsOptions chatCompletionsOptions = new(_deploymentName, chatMessages) + IEnumerable chatMessages = prompt.Output.Select(chatMessage => chatMessage.ToOpenAIChatMessage()); + + ChatCompletionOptions? chatCompletionOptions = ModelReaderWriter.Read(BinaryData.FromString($@"{{ + ""max_tokens"": {maxInputTokens}, + ""temperature"": {(float)promptTemplate.Configuration.Completion.Temperature}, + ""top_p"": {(float)promptTemplate.Configuration.Completion.TopP}, + ""presence_penalty"": {(float)promptTemplate.Configuration.Completion.PresencePenalty}, + ""frequency_penalty"": {(float)promptTemplate.Configuration.Completion.FrequencyPenalty} + }}")); + + if (chatCompletionOptions == null) { - MaxTokens = maxInputTokens, - Temperature = (float)promptTemplate.Configuration.Completion.Temperature, - NucleusSamplingFactor = (float)promptTemplate.Configuration.Completion.TopP, - PresencePenalty = (float)promptTemplate.Configuration.Completion.PresencePenalty, - FrequencyPenalty = (float)promptTemplate.Configuration.Completion.FrequencyPenalty, - }; + throw new TeamsAIException("Failed to create chat completions options"); + } + + // TODO: Use this once setters are added for the following fields in `OpenAI` package. + //OAIChat.ChatCompletionOptions chatCompletionsOptions = new() + //{ + // MaxTokens = maxInputTokens, + // Temperature = (float)promptTemplate.Configuration.Completion.Temperature, + // TopP = (float)promptTemplate.Configuration.Completion.TopP, + // PresencePenalty = (float)promptTemplate.Configuration.Completion.PresencePenalty, + // FrequencyPenalty = (float)promptTemplate.Configuration.Completion.FrequencyPenalty, + //}; IDictionary? 
additionalData = promptTemplate.Configuration.Completion.AdditionalData; - AddAzureChatExtensionConfigurations(chatCompletionsOptions, additionalData); + if (_useAzure) + { + AddAzureChatExtensionConfigurations(chatCompletionOptions, additionalData); + } - Response? rawResponse; - Response? chatCompletionsResponse = null; + PipelineResponse? rawResponse; + ClientResult? chatCompletionsResponse = null; PromptResponse promptResponse = new(); try { - chatCompletionsResponse = await _openAIClient.GetChatCompletionsAsync(chatCompletionsOptions, cancellationToken); + chatCompletionsResponse = await _openAIClient.GetChatClient(_deploymentName).CompleteChatAsync(chatMessages, chatCompletionOptions, cancellationToken); rawResponse = chatCompletionsResponse.GetRawResponse(); promptResponse.Status = PromptResponseStatus.Success; - promptResponse.Message = chatCompletionsResponse.Value.Choices[0].Message.ToChatMessage(); + promptResponse.Message = new ChatMessage(chatCompletionsResponse.Value); promptResponse.Input = input; } - catch (RequestFailedException e) + catch (ClientResultException e) { + // TODO: Verify if RequestFailedException is thrown when request fails. rawResponse = e.GetRawResponse(); - HttpOperationException httpOperationException = e.ToHttpOperationException(); + HttpOperationException httpOperationException = new(e); if (httpOperationException.StatusCode == (HttpStatusCode)429) { promptResponse.Status = PromptResponseStatus.RateLimited; @@ -294,27 +244,26 @@ public async Task CompletePromptAsync(ITurnContext turnContext, _logger.LogTrace(JsonSerializer.Serialize(rawResponse.Headers, _serializerOptions)); } } - return promptResponse; } + else + { + throw new TeamsAIException("The legacy completion endpoint has been deprecated, please use the chat completions endpoint instead"); + } } private ServiceVersion? 
ConvertStringToServiceVersion(string apiVersion) { - switch (apiVersion) + return apiVersion switch { - case "2022-12-01": return ServiceVersion.V2022_12_01; - case "2023-05-15": return ServiceVersion.V2023_05_15; - case "2023-06-01-preview": return ServiceVersion.V2023_06_01_Preview; - case "2023-07-01-preview": return ServiceVersion.V2023_07_01_Preview; - case "2024-02-15-preview": return ServiceVersion.V2024_02_15_Preview; - case "2024-03-01-preview": return ServiceVersion.V2024_03_01_Preview; - default: - return null; - } + "2024-04-01-preview" => ServiceVersion.V2024_04_01_Preview, + "2024-05-01-preview" => ServiceVersion.V2024_05_01_Preview, + "2024-06-01" => ServiceVersion.V2024_06_01, + _ => null, + }; } - private void AddAzureChatExtensionConfigurations(ChatCompletionsOptions options, IDictionary? additionalData) + private void AddAzureChatExtensionConfigurations(OAIChat.ChatCompletionOptions options, IDictionary? additionalData) { if (additionalData == null) { @@ -323,23 +272,15 @@ private void AddAzureChatExtensionConfigurations(ChatCompletionsOptions options, if (additionalData != null && additionalData.TryGetValue("data_sources", out JsonElement array)) { - List configurations = new(); List entries = array.Deserialize>()!; foreach (object item in entries) { - AzureChatExtensionConfiguration? dataSourceItem = ModelReaderWriter.Read(BinaryData.FromObjectAsJson(item)); - if (dataSourceItem != null) - { - configurations.Add(dataSourceItem); - } - } - - if (configurations.Count > 0) - { - options.AzureExtensionsOptions = new(); - foreach (AzureChatExtensionConfiguration configuration in configurations) + AzureChatDataSource? 
dataSource = ModelReaderWriter.Read(BinaryData.FromObjectAsJson(item)); + if (dataSource != null) { - options.AzureExtensionsOptions.Extensions.Add(configuration); +#pragma warning disable AOAI001 + options.AddDataSource(dataSource); +#pragma warning restore AOAI001 } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/RequestFailedExceptionExtensions.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/RequestFailedExceptionExtensions.cs deleted file mode 100644 index d800dac25..000000000 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/RequestFailedExceptionExtensions.cs +++ /dev/null @@ -1,30 +0,0 @@ -using Azure; -using Microsoft.Teams.AI.Exceptions; -using System.Net; - -namespace Microsoft.Teams.AI.AI.Models -{ - /// - /// Provides extension methods for the class. - /// - internal static class RequestFailedExceptionExtensions - { - /// - /// Converts a to an . - /// - /// The original . - /// An instance. - public static HttpOperationException ToHttpOperationException(this RequestFailedException exception) - { - string? responseContent = null; - - try - { - responseContent = exception.GetRawResponse()?.Content?.ToString(); - } - catch { } - - return new HttpOperationException(exception.Message, (HttpStatusCode)exception.Status, responseContent); - } - } -} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayRetryPolicy.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayRetryPolicy.cs new file mode 100644 index 000000000..f1f2f44d1 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayRetryPolicy.cs @@ -0,0 +1,23 @@ +using System.ClientModel.Primitives; + +namespace Microsoft.Teams.AI.AI.Models +{ + /// + /// A customized delay retry policy that uses a fixed sequence of delays that are iterated through as the number of retries increases. 
+ /// + internal class SequentialDelayRetryPolicy : ClientRetryPolicy + { + private List _delays; + + public SequentialDelayRetryPolicy(List delays, int maxRetries = 3) : base(maxRetries) + { + this._delays = delays; + } + + protected override TimeSpan GetNextDelay(PipelineMessage message, int tryCount) + { + int index = tryCount - 1; + return index >= _delays.Count ? _delays[_delays.Count - 1] : _delays[index]; + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayStrategy.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayStrategy.cs deleted file mode 100644 index 120a40a1c..000000000 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/AI/Models/SequentialDelayStrategy.cs +++ /dev/null @@ -1,24 +0,0 @@ -using Azure; -using Azure.Core; - -namespace Microsoft.Teams.AI.AI.Models -{ - /// - /// A customized delay strategy that uses a fixed sequence of delays that are iterated through as the number of retries increases. - /// - internal class SequentialDelayStrategy : DelayStrategy - { - private List _delays; - - public SequentialDelayStrategy(List delays) - { - this._delays = delays; - } - - protected override TimeSpan GetNextDelayCore(Response? response, int retryNumber) - { - int index = retryNumber - 1; - return index >= _delays.Count ? 
_delays[_delays.Count - 1] : _delays[index]; - } - } -} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Exceptions/HttpOperationException.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Exceptions/HttpOperationException.cs index a3adf9a96..d0bd49613 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Exceptions/HttpOperationException.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Exceptions/HttpOperationException.cs @@ -1,4 +1,5 @@ -using System.Net; +using System.ClientModel; +using System.Net; namespace Microsoft.Teams.AI.Exceptions { @@ -29,6 +30,24 @@ public HttpOperationException(string message, HttpStatusCode? httpStatusCode = n ResponseContent = responseContent; } + /// + /// Create an instance of the HttpOperationException class using the ClientResultException class + /// + /// The client result exception. + internal HttpOperationException(ClientResultException exception) : base(exception.Message) + { + string? responseContent = null; + + try + { + responseContent = exception.GetRawResponse()?.Content?.ToString(); + } + catch { } + + StatusCode = (HttpStatusCode)exception.Status; + ResponseContent = responseContent; + } + /// /// Checks status code is a http error status code. 
/// diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Microsoft.Teams.AI.csproj b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Microsoft.Teams.AI.csproj index c68ba4353..16828fb7e 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Microsoft.Teams.AI.csproj +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Microsoft.Teams.AI.csproj @@ -38,7 +38,7 @@ - + @@ -47,6 +47,7 @@ + diff --git a/dotnet/samples/01.messaging.echoBot/EchoBot.csproj b/dotnet/samples/01.messaging.echoBot/EchoBot.csproj index 151e097bf..cce7ec0c1 100644 --- a/dotnet/samples/01.messaging.echoBot/EchoBot.csproj +++ b/dotnet/samples/01.messaging.echoBot/EchoBot.csproj @@ -16,7 +16,7 @@ - + diff --git a/dotnet/samples/02.messageExtensions.a.searchCommand/SearchCommand.csproj b/dotnet/samples/02.messageExtensions.a.searchCommand/SearchCommand.csproj index 7dcfead4c..a0939b873 100644 --- a/dotnet/samples/02.messageExtensions.a.searchCommand/SearchCommand.csproj +++ b/dotnet/samples/02.messageExtensions.a.searchCommand/SearchCommand.csproj @@ -13,7 +13,7 @@ - + diff --git a/dotnet/samples/03.adaptiveCards.a.typeAheadBot/TypeAheadBot.csproj b/dotnet/samples/03.adaptiveCards.a.typeAheadBot/TypeAheadBot.csproj index ada043ed1..e73f46eed 100644 --- a/dotnet/samples/03.adaptiveCards.a.typeAheadBot/TypeAheadBot.csproj +++ b/dotnet/samples/03.adaptiveCards.a.typeAheadBot/TypeAheadBot.csproj @@ -11,8 +11,8 @@ - - + + diff --git a/dotnet/samples/04.ai.a.teamsChefBot/TeamsChefBot.csproj b/dotnet/samples/04.ai.a.teamsChefBot/TeamsChefBot.csproj index c728d139f..249e30f73 100644 --- a/dotnet/samples/04.ai.a.teamsChefBot/TeamsChefBot.csproj +++ b/dotnet/samples/04.ai.a.teamsChefBot/TeamsChefBot.csproj @@ -11,9 +11,9 @@ - - - + + + diff --git a/dotnet/samples/04.ai.b.messageExtensions.gptME/GPT.csproj b/dotnet/samples/04.ai.b.messageExtensions.gptME/GPT.csproj index e51167826..7d2e3702f 100644 --- a/dotnet/samples/04.ai.b.messageExtensions.gptME/GPT.csproj +++ 
b/dotnet/samples/04.ai.b.messageExtensions.gptME/GPT.csproj @@ -14,7 +14,7 @@ - + diff --git a/dotnet/samples/04.ai.c.actionMapping.lightBot/LightBot.csproj b/dotnet/samples/04.ai.c.actionMapping.lightBot/LightBot.csproj index d4fef2ce6..823365ba7 100644 --- a/dotnet/samples/04.ai.c.actionMapping.lightBot/LightBot.csproj +++ b/dotnet/samples/04.ai.c.actionMapping.lightBot/LightBot.csproj @@ -17,7 +17,7 @@ - + diff --git a/dotnet/samples/04.ai.d.chainedActions.listBot/ListBot.csproj b/dotnet/samples/04.ai.d.chainedActions.listBot/ListBot.csproj index bfd1f9c60..54acfaf58 100644 --- a/dotnet/samples/04.ai.d.chainedActions.listBot/ListBot.csproj +++ b/dotnet/samples/04.ai.d.chainedActions.listBot/ListBot.csproj @@ -11,7 +11,7 @@ - + diff --git a/dotnet/samples/04.ai.e.chainedActions.devOpsBot/DevOpsBot.csproj b/dotnet/samples/04.ai.e.chainedActions.devOpsBot/DevOpsBot.csproj index 263137910..2dbdfac63 100644 --- a/dotnet/samples/04.ai.e.chainedActions.devOpsBot/DevOpsBot.csproj +++ b/dotnet/samples/04.ai.e.chainedActions.devOpsBot/DevOpsBot.csproj @@ -13,7 +13,7 @@ - + diff --git a/dotnet/samples/04.ai.f.vision.cardMaster/CardGazer.csproj b/dotnet/samples/04.ai.f.vision.cardMaster/CardGazer.csproj index 1692da6ea..1caf6aaff 100644 --- a/dotnet/samples/04.ai.f.vision.cardMaster/CardGazer.csproj +++ b/dotnet/samples/04.ai.f.vision.cardMaster/CardGazer.csproj @@ -17,7 +17,7 @@ - + diff --git a/dotnet/samples/04.e.twentyQuestions/TwentyQuestions.csproj b/dotnet/samples/04.e.twentyQuestions/TwentyQuestions.csproj index d73bf6122..b9e449434 100644 --- a/dotnet/samples/04.e.twentyQuestions/TwentyQuestions.csproj +++ b/dotnet/samples/04.e.twentyQuestions/TwentyQuestions.csproj @@ -12,7 +12,7 @@ - + diff --git a/dotnet/samples/06.assistants.a.mathBot/MathBot.csproj b/dotnet/samples/06.assistants.a.mathBot/MathBot.csproj index 426a47abc..6c7c765e7 100644 --- a/dotnet/samples/06.assistants.a.mathBot/MathBot.csproj +++ b/dotnet/samples/06.assistants.a.mathBot/MathBot.csproj @@ 
-12,7 +12,7 @@ - + diff --git a/dotnet/samples/06.assistants.b.orderBot/OrderBot.csproj b/dotnet/samples/06.assistants.b.orderBot/OrderBot.csproj index 48dc5dcc2..dcd18bd4e 100644 --- a/dotnet/samples/06.assistants.b.orderBot/OrderBot.csproj +++ b/dotnet/samples/06.assistants.b.orderBot/OrderBot.csproj @@ -13,7 +13,7 @@ - + diff --git a/dotnet/samples/06.auth.oauth.bot/BotAuth.csproj b/dotnet/samples/06.auth.oauth.bot/BotAuth.csproj index 11f7181b3..0072e58ec 100644 --- a/dotnet/samples/06.auth.oauth.bot/BotAuth.csproj +++ b/dotnet/samples/06.auth.oauth.bot/BotAuth.csproj @@ -16,7 +16,7 @@ - + diff --git a/dotnet/samples/06.auth.oauth.messageExtension/MessageExtensionAuth.csproj b/dotnet/samples/06.auth.oauth.messageExtension/MessageExtensionAuth.csproj index 476f68dda..ac8dc4e7c 100644 --- a/dotnet/samples/06.auth.oauth.messageExtension/MessageExtensionAuth.csproj +++ b/dotnet/samples/06.auth.oauth.messageExtension/MessageExtensionAuth.csproj @@ -13,7 +13,7 @@ - + diff --git a/dotnet/samples/06.auth.teamsSSO.bot/BotAuth.csproj b/dotnet/samples/06.auth.teamsSSO.bot/BotAuth.csproj index 01abf8656..5155915e1 100644 --- a/dotnet/samples/06.auth.teamsSSO.bot/BotAuth.csproj +++ b/dotnet/samples/06.auth.teamsSSO.bot/BotAuth.csproj @@ -16,7 +16,7 @@ - + diff --git a/dotnet/samples/06.auth.teamsSSO.messageExtension/MessageExtensionAuth.csproj b/dotnet/samples/06.auth.teamsSSO.messageExtension/MessageExtensionAuth.csproj index 97fdd0128..ea4b33de4 100644 --- a/dotnet/samples/06.auth.teamsSSO.messageExtension/MessageExtensionAuth.csproj +++ b/dotnet/samples/06.auth.teamsSSO.messageExtension/MessageExtensionAuth.csproj @@ -13,7 +13,7 @@ - + diff --git a/dotnet/samples/08.datasource.azureaisearch/AzureAISearchBot/AzureAISearchBot.csproj b/dotnet/samples/08.datasource.azureaisearch/AzureAISearchBot/AzureAISearchBot.csproj index 57782e173..59c0a08de 100644 --- a/dotnet/samples/08.datasource.azureaisearch/AzureAISearchBot/AzureAISearchBot.csproj +++ 
b/dotnet/samples/08.datasource.azureaisearch/AzureAISearchBot/AzureAISearchBot.csproj @@ -12,9 +12,9 @@ - - - + + + diff --git a/dotnet/samples/08.datasource.azureopenai/AzureOpenAIBot.csproj b/dotnet/samples/08.datasource.azureopenai/AzureOpenAIBot.csproj index 092c1fa3f..7a77a5383 100644 --- a/dotnet/samples/08.datasource.azureopenai/AzureOpenAIBot.csproj +++ b/dotnet/samples/08.datasource.azureopenai/AzureOpenAIBot.csproj @@ -11,9 +11,9 @@ - - - + + +