From c11ab2981b42d4c8f4055cebb06edda563da5da3 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:08:13 +0100 Subject: [PATCH] .Net: New Azure AI Inference Connector (#7963) # Motivation and Context This PR brings support for Azure AI Studio Model Catalog models, which can also be deployed through GitHub Models. The connector uses the `Azure AI Inference SDK` client library. Closes #3992 Closes #7958 --- .github/workflows/dotnet-build-and-test.yml | 2 + .../0051-dotnet-azure-model-as-a-service.md | 46 ++ dotnet/Directory.Packages.props | 1 + dotnet/SK-dotnet.sln | 18 + .../AzureAIInference_ChatCompletion.cs | 97 +++ ...zureAIInference_ChatCompletionStreaming.cs | 176 +++++ .../Google_GeminiChatCompletion.cs | 4 +- .../Google_GeminiChatCompletionStreaming.cs | 4 +- .../ChatCompletion/OpenAI_ChatCompletion.cs | 4 +- .../OpenAI_ChatCompletionStreaming.cs | 4 +- dotnet/samples/Concepts/Concepts.csproj | 1 + .../Demos/AIModelRouter/AIModelRouter.csproj | 1 + dotnet/samples/Demos/AIModelRouter/Program.cs | 30 +- ...nnectors.AzureAIInference.UnitTests.csproj | 48 ++ .../Core/ChatClientCoreTests.cs | 184 +++++ ...AIInferenceKernelBuilderExtensionsTests.cs | 49 ++ ...ferenceServiceCollectionExtensionsTests.cs | 50 ++ ...reAIInferenceChatCompletionServiceTests.cs | 280 ++++++++ ...AIInferencePromptExecutionSettingsTests.cs | 240 +++++++ .../TestData/chat_completion_response.json | 22 + .../chat_completion_streaming_response.txt | 7 + .../AssemblyInfo.cs | 6 + .../Connectors.AzureAIInference.csproj | 34 + .../Core/AddHeaderRequestPolicy.cs | 20 + .../Core/ChatClientCore.cs | 649 ++++++++++++++++++ .../Core/RequestFailedExceptionExtensions.cs | 38 + ...AzureAIInferenceKernelBuilderExtensions.cs | 86 +++ ...eAIInferenceServiceCollectionExtensions.cs | 106 +++ .../AzureAIInferenceChatCompletionService.cs | 96 +++ ...AzureAIInferencePromptExecutionSettings.cs | 281 ++++++++ .../Core/ClientCoreTests.cs | 6 +- ...reAIInferenceChatCompletionServiceTests.cs | 255 +++++++ .../IntegrationTests/IntegrationTests.csproj | 1 + dotnet/src/IntegrationTests/README.md | 4 + .../AzureAIInferenceConfiguration.cs | 15 + dotnet/src/IntegrationTests/testsettings.json | 5 + .../InternalUtilities/TestConfiguration.cs | 8 + .../ChatCompletionServiceExtensions.cs | 6 +- .../ChatPlugin/Chat/config.json | 4 +- .../FunPlugin/Limerick/config.json | 6 +- 40 files changed, 2871 insertions(+), 23 deletions(-) create mode 100644 docs/decisions/0051-dotnet-azure-model-as-a-service.md create mode 100644 dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs create mode 100644 dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Core/ChatClientCoreTests.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Settings/AzureAIInferencePromptExecutionSettingsTests.cs create mode 100644 
dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_response.json create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_streaming_response.txt create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/AssemblyInfo.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Core/AddHeaderRequestPolicy.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Core/RequestFailedExceptionExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureAIInference/Settings/AzureAIInferencePromptExecutionSettings.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs create mode 100644 dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs diff --git a/.github/workflows/dotnet-build-and-test.yml b/.github/workflows/dotnet-build-and-test.yml index 441164296b70..f46adb441a41 100644 --- a/.github/workflows/dotnet-build-and-test.yml +++ b/.github/workflows/dotnet-build-and-test.yml @@ -125,6 +125,8 @@ jobs: Bing__ApiKey: ${{ secrets.BING__APIKEY }} OpenAI__ApiKey: ${{ secrets.OPENAI__APIKEY }} OpenAI__ChatModelId: ${{ vars.OPENAI__CHATMODELID }} + AzureAIInference__ApiKey: ${{ secrets.AZUREAIINFERENCE__APIKEY }} + AzureAIInference__Endpoint: ${{ secrets.AZUREAIINFERENCE__ENDPOINT }} # Generate test reports and check coverage - name: Generate test reports diff --git a/docs/decisions/0051-dotnet-azure-model-as-a-service.md b/docs/decisions/0051-dotnet-azure-model-as-a-service.md new file mode 100644 index 000000000000..b023838d5128 --- /dev/null +++ b/docs/decisions/0051-dotnet-azure-model-as-a-service.md @@ -0,0 +1,46 @@ +--- +# These are optional elements. Feel free to remove any of them. +status: proposed +contact: rogerbarreto +date: 2024-08-07 +deciders: rogerbarreto, markwallace-microsoft +consulted: taochen +--- + +# Support Connector for .Net Azure Model-as-a-Service (Azure AI Studio) + +## Context and Problem Statement + +Customers have asked for native support of models deployed in [Azure AI Studio - Serverless APIs](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/model-catalog-overview#model-deployment-managed-compute-and-serverless-api-pay-as-you-go). This mode of consumption operates on a pay-as-you-go basis, typically using tokens for billing purposes. Clients can access the service via the [Azure AI Model Inference API](https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-api?tabs=azure-studio) or client SDKs. + +At present, there is no official support for [Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/what-is-ai-studio). The purpose of this ADR is to examine the constraints of the service and explore potential solutions for supporting it through a new AI connector. 
+ +## Azure Inference Client library for .NET + +The Azure team provides a new client library, [Azure.AI.Inference](https://github.com/Azure/azure-sdk-for-net/blob/Azure.AI.Inference_1.0.0-beta.1/sdk/ai/Azure.AI.Inference/README.md) for .NET, for interacting with the service effectively. Although the service API is OpenAI-compatible, the OpenAI and Azure OpenAI client libraries cannot be used with it, because they are not independent of their respective models and providers: Azure AI Studio features a diverse range of open-source models beyond the OpenAI catalog. + +### Limitations + +It is currently known that the first version of the client SDK will support only `Chat Completion`, `Text Embedding Generation`, and `Image Embedding Generation`, with `TextToImage Generation` planned. + +There are no current plans to support the `Text Generation` modality. + +## AI Connector + +### Namespace options + +- `Microsoft.SemanticKernel.Connectors.AzureAI` +- `Microsoft.SemanticKernel.Connectors.AzureAIInference` +- `Microsoft.SemanticKernel.Connectors.AzureAIModelInference` + +Decision: `Microsoft.SemanticKernel.Connectors.AzureAIInference` + +### Support for model-specific parameters + +Models can expose supplementary parameters that are not part of the default API. Both the service API and the client SDK allow callers to pass such model-specific parameters: users provide model-specific settings via a dedicated argument, alongside common settings such as `temperature` and `top_p`. + +A specialized Azure AI Inference `PromptExecutionSettings` class will support these customizable parameters, as shown in the usage sketch at the end of this document. + +### Feature Branch + +The development of the Azure AI Inference connector will be done in a feature branch named `feature-connectors-azureaiinference`. 
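+### Usage sketch
+
+The following is a minimal sketch of how the connector is expected to be consumed once published, based on the extension methods and settings class in this PR. It is illustrative only: the model name, endpoint, and environment variable below are placeholder assumptions, not values mandated by the connector.
+
+```csharp
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Connectors.AzureAIInference;
+
+// Register the Azure AI Inference chat completion service.
+// Endpoint and key come from an Azure AI Studio serverless deployment (placeholders here).
+var kernel = Kernel.CreateBuilder()
+    .AddAzureAIInferenceChatCompletion(
+        modelId: "phi3-medium-4k", // hypothetical deployed model name
+        apiKey: Environment.GetEnvironmentVariable("AZUREAIINFERENCE_APIKEY"),
+        endpoint: new Uri("https://contoso.models.ai.azure.com")) // hypothetical endpoint
+    .Build();
+
+// Common and model-specific parameters travel through the specialized settings class.
+var settings = new AzureAIInferencePromptExecutionSettings
+{
+    Temperature = 0.5f,
+    NucleusSamplingFactor = 0.9f, // serialized as top_p
+    MaxTokens = 256
+};
+
+var reply = await kernel.InvokePromptAsync(
+    "Suggest three books about ancient Greece.",
+    new KernelArguments(settings));
+
+Console.WriteLine(reply);
+```
+
+Passing the settings through `KernelArguments` mirrors how other Semantic Kernel connectors surface their execution settings, which keeps the router and prompt APIs unchanged. 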
diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 5fc5b81af480..de902f1fa7cd 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -5,6 +5,7 @@ true + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 3a7241ac500c..ebabff7f3ceb 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -334,6 +334,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.AzureOpenAI", "s EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.AzureOpenAI.UnitTests", "src\Connectors\Connectors.AzureOpenAI.UnitTests\Connectors.AzureOpenAI.UnitTests.csproj", "{8CF06B22-50F3-4F71-A002-622DB49DF0F5}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.AzureAIInference", "src\Connectors\Connectors.AzureAIInference\Connectors.AzureAIInference.csproj", "{063044B2-A901-43C5-BFDF-5E4E71C7BC33}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.AzureAIInference.UnitTests", "src\Connectors\Connectors.AzureAIInference.UnitTests\Connectors.AzureAIInference.UnitTests.csproj", "{E0D45DDB-6D32-40FC-AC79-E1F342C4F513}" +EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OnnxSimpleRAG", "samples\Demos\OnnxSimpleRAG\OnnxSimpleRAG.csproj", "{8972254B-B8F0-4119-953B-378E3BACA59A}" EndProject Global @@ -853,6 +857,18 @@ Global {8CF06B22-50F3-4F71-A002-622DB49DF0F5}.Publish|Any CPU.Build.0 = Debug|Any CPU {8CF06B22-50F3-4F71-A002-622DB49DF0F5}.Release|Any CPU.ActiveCfg = Release|Any CPU {8CF06B22-50F3-4F71-A002-622DB49DF0F5}.Release|Any CPU.Build.0 = Release|Any CPU + {063044B2-A901-43C5-BFDF-5E4E71C7BC33}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {063044B2-A901-43C5-BFDF-5E4E71C7BC33}.Debug|Any CPU.Build.0 = Debug|Any CPU + {063044B2-A901-43C5-BFDF-5E4E71C7BC33}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {063044B2-A901-43C5-BFDF-5E4E71C7BC33}.Publish|Any CPU.Build.0 = Publish|Any CPU + {063044B2-A901-43C5-BFDF-5E4E71C7BC33}.Release|Any CPU.ActiveCfg = Release|Any CPU + {063044B2-A901-43C5-BFDF-5E4E71C7BC33}.Release|Any CPU.Build.0 = Release|Any CPU + {E0D45DDB-6D32-40FC-AC79-E1F342C4F513}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E0D45DDB-6D32-40FC-AC79-E1F342C4F513}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E0D45DDB-6D32-40FC-AC79-E1F342C4F513}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {E0D45DDB-6D32-40FC-AC79-E1F342C4F513}.Publish|Any CPU.Build.0 = Debug|Any CPU + {E0D45DDB-6D32-40FC-AC79-E1F342C4F513}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E0D45DDB-6D32-40FC-AC79-E1F342C4F513}.Release|Any CPU.Build.0 = Release|Any CPU {8972254B-B8F0-4119-953B-378E3BACA59A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {8972254B-B8F0-4119-953B-378E3BACA59A}.Debug|Any CPU.Build.0 = Debug|Any CPU {8972254B-B8F0-4119-953B-378E3BACA59A}.Publish|Any CPU.ActiveCfg = Debug|Any CPU @@ -975,6 +991,8 @@ Global {36DDC119-C030-407E-AC51-A877E9E0F660} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} {7AAD7388-307D-41FB-B80A-EF9E3A4E31F0} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} {8CF06B22-50F3-4F71-A002-622DB49DF0F5} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} + {063044B2-A901-43C5-BFDF-5E4E71C7BC33} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} + {E0D45DDB-6D32-40FC-AC79-E1F342C4F513} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} {8972254B-B8F0-4119-953B-378E3BACA59A} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution diff --git a/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs 
b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs new file mode 100644 index 000000000000..38f2add47fa6 --- /dev/null +++ b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletion.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; + +namespace ChatCompletion; + +// The following example shows how to use Semantic Kernel with Azure AI Inference / Azure AI Studio +public class AzureAIInference_ChatCompletion(ITestOutputHelper output) : BaseTest(output) +{ + [Fact] + public async Task ServicePromptAsync() + { + Console.WriteLine("======== Azure AI Inference - Chat Completion ========"); + + var chatService = new AzureAIInferenceChatCompletionService( + endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), + apiKey: TestConfiguration.AzureAIInference.ApiKey); + + Console.WriteLine("Chat content:"); + Console.WriteLine("------------------------"); + + var chatHistory = new ChatHistory("You are a librarian, expert about books"); + + // First user message + chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); + OutputLastMessage(chatHistory); + + // First assistant message + var reply = await chatService.GetChatMessageContentAsync(chatHistory); + chatHistory.Add(reply); + OutputLastMessage(chatHistory); + + // Second user message + chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); + OutputLastMessage(chatHistory); + + // Second assistant message + reply = await chatService.GetChatMessageContentAsync(chatHistory); + chatHistory.Add(reply); + OutputLastMessage(chatHistory); + + /* Output: + + Chat content: + ------------------------ + System: You are a librarian, expert about books + ------------------------ + User: Hi, I'm looking for book suggestions + ------------------------ + Assistant: Sure, I'd be happy to help! What kind of books are you interested in? Fiction or non-fiction? Any particular genre? + ------------------------ + User: I love history and philosophy, I'd like to learn something new about Greece, any suggestion? + ------------------------ + Assistant: Great! For history and philosophy books about Greece, here are a few suggestions: + + 1. "The Greeks" by H.D.F. Kitto - This is a classic book that provides an overview of ancient Greek history and culture, including their philosophy, literature, and art. + + 2. "The Republic" by Plato - This is one of the most famous works of philosophy in the Western world, and it explores the nature of justice and the ideal society. + + 3. "The Peloponnesian War" by Thucydides - This is a detailed account of the war between Athens and Sparta in the 5th century BCE, and it provides insight into the political and military strategies of the time. + + 4. "The Iliad" by Homer - This epic poem tells the story of the Trojan War and is considered one of the greatest works of literature in the Western canon. + + 5. "The Histories" by Herodotus - This is a comprehensive account of the Persian Wars and provides a wealth of information about ancient Greek culture and society. + + I hope these suggestions are helpful! 
+ ------------------------ + */ + } + + [Fact] + public async Task ChatPromptAsync() + { + StringBuilder chatPrompt = new(""" + <message role="system">You are a librarian, expert about books</message> + <message role="user">Hi, I'm looking for book suggestions</message> + """); + + var kernel = Kernel.CreateBuilder() + .AddAzureAIInferenceChatCompletion( + endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), + apiKey: TestConfiguration.AzureAIInference.ApiKey) + .Build(); + + var reply = await kernel.InvokePromptAsync(chatPrompt.ToString()); + + chatPrompt.AppendLine($"<message role=\"assistant\"><![CDATA[{reply}]]></message>"); + chatPrompt.AppendLine("<message role=\"user\">I love history and philosophy, I'd like to learn something new about Greece, any suggestion</message>"); + + reply = await kernel.InvokePromptAsync(chatPrompt.ToString()); + + Console.WriteLine(reply); + } +} diff --git a/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs new file mode 100644 index 000000000000..62c1fd3dcb11 --- /dev/null +++ b/dotnet/samples/Concepts/ChatCompletion/AzureAIInference_ChatCompletionStreaming.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; + +namespace ChatCompletion; + +/// +/// These examples demonstrate the ways different content types are streamed by Azure AI Inference via the chat completion service. +/// +public class AzureAIInference_ChatCompletionStreaming(ITestOutputHelper output) : BaseTest(output) +{ + /// + /// This example demonstrates chat completion streaming using Azure AI Inference. + /// + [Fact] + public Task StreamChatAsync() + { + Console.WriteLine("======== Azure AI Inference - Chat Completion Streaming ========"); + + var chatService = new AzureAIInferenceChatCompletionService( + endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), + apiKey: TestConfiguration.AzureAIInference.ApiKey); + + return this.StartStreamingChatAsync(chatService); + } + + /// + /// This example demonstrates chat completion streaming using Azure AI Inference via the kernel. + /// + [Fact] + public async Task StreamChatPromptAsync() + { + Console.WriteLine("======== Azure AI Inference - Chat Prompt Completion Streaming ========"); + + StringBuilder chatPrompt = new(""" + <message role="system">You are a librarian, expert about books</message> + <message role="user">Hi, I'm looking for book suggestions</message> + """); + + var kernel = Kernel.CreateBuilder() + .AddAzureAIInferenceChatCompletion( + endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), + apiKey: TestConfiguration.AzureAIInference.ApiKey) + .Build(); + + var reply = await StreamMessageOutputFromKernelAsync(kernel, chatPrompt.ToString()); + + chatPrompt.AppendLine($"<message role=\"assistant\"><![CDATA[{reply}]]></message>"); + chatPrompt.AppendLine("<message role=\"user\">I love history and philosophy, I'd like to learn something new about Greece, any suggestion</message>"); + + reply = await StreamMessageOutputFromKernelAsync(kernel, chatPrompt.ToString()); + + Console.WriteLine(reply); + } + + /// + /// This example demonstrates how the chat completion service streams text content. + /// It shows how to access the response update via StreamingChatMessageContent.Content property + /// and alternatively via the StreamingChatMessageContent.Items property. 
+ /// + [Fact] + public async Task StreamTextFromChatAsync() + { + Console.WriteLine("======== Stream Text from Chat Content ========"); + + // Create chat completion service + var chatService = new AzureAIInferenceChatCompletionService( + endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint), + apiKey: TestConfiguration.AzureAIInference.ApiKey); + + // Create chat history with initial system and user messages + ChatHistory chatHistory = new("You are a librarian, an expert on books."); + chatHistory.AddUserMessage("Hi, I'm looking for book suggestions."); + chatHistory.AddUserMessage("I love history and philosophy. I'd like to learn something new about Greece, any suggestion?"); + + // Start streaming chat based on the chat history + await foreach (StreamingChatMessageContent chatUpdate in chatService.GetStreamingChatMessageContentsAsync(chatHistory)) + { + // Access the response update via StreamingChatMessageContent.Content property + Console.Write(chatUpdate.Content); + + // Alternatively, the response update can be accessed via the StreamingChatMessageContent.Items property + Console.Write(chatUpdate.Items.OfType<StreamingTextContent>().FirstOrDefault()); + } + } + + /// + /// Starts streaming chat with the chat completion service. + /// + /// The chat completion service instance. + private async Task StartStreamingChatAsync(IChatCompletionService chatCompletionService) + { + Console.WriteLine("Chat content:"); + Console.WriteLine("------------------------"); + + var chatHistory = new ChatHistory("You are a librarian, expert about books"); + OutputLastMessage(chatHistory); + + // First user message + chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); + OutputLastMessage(chatHistory); + + // First assistant message + await StreamMessageOutputAsync(chatCompletionService, chatHistory, AuthorRole.Assistant); + + // Second user message + chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion?"); + OutputLastMessage(chatHistory); + + // Second assistant message + await StreamMessageOutputAsync(chatCompletionService, chatHistory, AuthorRole.Assistant); + } + + /// + /// Streams the message output from the chat completion service. + /// + /// The chat completion service instance. + /// The chat history instance. + /// The author role. + private async Task StreamMessageOutputAsync(IChatCompletionService chatCompletionService, ChatHistory chatHistory, AuthorRole authorRole) + { + bool roleWritten = false; + string fullMessage = string.Empty; + + await foreach (var chatUpdate in chatCompletionService.GetStreamingChatMessageContentsAsync(chatHistory)) + { + if (!roleWritten && chatUpdate.Role.HasValue) + { + Console.Write($"{chatUpdate.Role.Value}: {chatUpdate.Content}"); + roleWritten = true; + } + + if (chatUpdate.Content is { Length: > 0 }) + { + fullMessage += chatUpdate.Content; + Console.Write(chatUpdate.Content); + } + } + + Console.WriteLine("\n------------------------"); + chatHistory.AddMessage(authorRole, fullMessage); + } + + /// + /// Outputs the chat history by streaming the message output from the kernel. + /// + /// The kernel instance. + /// The prompt message. + /// The full message output from the kernel. 
+ private async Task<string> StreamMessageOutputFromKernelAsync(Kernel kernel, string prompt) + { + bool roleWritten = false; + string fullMessage = string.Empty; + + await foreach (var chatUpdate in kernel.InvokePromptStreamingAsync(prompt)) + { + if (!roleWritten && chatUpdate.Role.HasValue) + { + Console.Write($"{chatUpdate.Role.Value}: {chatUpdate.Content}"); + roleWritten = true; + } + + if (chatUpdate.Content is { Length: > 0 }) + { + fullMessage += chatUpdate.Content; + Console.Write(chatUpdate.Content); + } + } + + Console.WriteLine("\n------------------------"); + return fullMessage; + } +} diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs index 2e8f750e5476..f5963698ce0d 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs @@ -96,7 +96,7 @@ private async Task SimpleChatAsync(Kernel kernel) chatHistory.AddUserMessage("Hi, I'm looking for new power tools, any suggestion?"); await MessageOutputAsync(chatHistory); - // First bot assistant message + // First assistant message var reply = await chat.GetChatMessageContentAsync(chatHistory); chatHistory.Add(reply); await MessageOutputAsync(chatHistory); @@ -105,7 +105,7 @@ private async Task SimpleChatAsync(Kernel kernel) chatHistory.AddUserMessage("I'm looking for a drill, a screwdriver and a hammer."); await MessageOutputAsync(chatHistory); - // Second bot assistant message + // Second assistant message reply = await chat.GetChatMessageContentAsync(chatHistory); chatHistory.Add(reply); await MessageOutputAsync(chatHistory); diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs index 803a6b6fafcd..2b6f7b1f7556 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs @@ -97,7 +97,7 @@ private async Task StreamingChatAsync(Kernel kernel) chatHistory.AddUserMessage("Hi, I'm looking for alternative coffee brew methods, can you help me?"); await MessageOutputAsync(chatHistory); - // First bot assistant message + // First assistant message var streamingChat = chat.GetStreamingChatMessageContentsAsync(chatHistory); var reply = await MessageOutputAsync(streamingChat); chatHistory.Add(reply); @@ -106,7 +106,7 @@ private async Task StreamingChatAsync(Kernel kernel) chatHistory.AddUserMessage("Give me the best speciality coffee roasters."); await MessageOutputAsync(chatHistory); - // Second bot assistant message + // Second assistant message streamingChat = chat.GetStreamingChatMessageContentsAsync(chatHistory); reply = await MessageOutputAsync(streamingChat); chatHistory.Add(reply); diff --git a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs index a92c86dd977d..b8825e332c97 100644 --- a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs +++ b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs @@ -91,7 +91,7 @@ private async Task StartChatAsync(IChatCompletionService chatGPT) chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); OutputLastMessage(chatHistory); - // First bot assistant message + // First assistant message var reply = await chatGPT.GetChatMessageContentAsync(chatHistory); 
chatHistory.Add(reply); OutputLastMessage(chatHistory); @@ -100,7 +100,7 @@ private async Task StartChatAsync(IChatCompletionService chatGPT) chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); OutputLastMessage(chatHistory); - // Second bot assistant message + // Second assistant message reply = await chatGPT.GetChatMessageContentAsync(chatHistory); chatHistory.Add(reply); OutputLastMessage(chatHistory); diff --git a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs index fe0052a52db2..c6888bcedd25 100644 --- a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs +++ b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs @@ -128,14 +128,14 @@ private async Task StartStreamingChatAsync(IChatCompletionService chatCompletion chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); OutputLastMessage(chatHistory); - // First bot assistant message + // First assistant message await StreamMessageOutputAsync(chatCompletionService, chatHistory, AuthorRole.Assistant); // Second user message chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion?"); OutputLastMessage(chatHistory); - // Second bot assistant message + // Second assistant message await StreamMessageOutputAsync(chatCompletionService, chatHistory, AuthorRole.Assistant); } diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index d417f12de6ea..25724b822243 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -48,6 +48,7 @@ + diff --git a/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj b/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj index 76964283d69c..4ce04e354cc8 100644 --- a/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj +++ b/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj @@ -15,6 +15,7 @@ + diff --git a/dotnet/samples/Demos/AIModelRouter/Program.cs b/dotnet/samples/Demos/AIModelRouter/Program.cs index 6b3183ef0bbb..9d3631dbcb90 100644 --- a/dotnet/samples/Demos/AIModelRouter/Program.cs +++ b/dotnet/samples/Demos/AIModelRouter/Program.cs @@ -36,22 +36,44 @@ private static async Task Main(string[] args) if (config["Ollama:ModelId"] is not null) { - services.AddOllamaChatCompletion(serviceId: "ollama", modelId: config["Ollama:ModelId"]!, endpoint: new Uri(config["Ollama:Endpoint"] ?? "http://localhost:11434")); + services.AddOllamaChatCompletion( + serviceId: "ollama", + modelId: config["Ollama:ModelId"]!, + endpoint: new Uri(config["Ollama:Endpoint"] ?? "http://localhost:11434")); + Console.WriteLine("• Ollama - Use \"ollama\" in the prompt."); } if (config["OpenAI:ApiKey"] is not null) { - services.AddOpenAIChatCompletion(serviceId: "openai", modelId: config["OpenAI:ModelId"] ?? "gpt-4o", apiKey: config["OpenAI:ApiKey"]!); + services.AddOpenAIChatCompletion( + serviceId: "openai", + modelId: config["OpenAI:ModelId"] ?? 
"gpt-4o", + apiKey: config["OpenAI:ApiKey"]!); + Console.WriteLine("• OpenAI Added - Use \"openai\" in the prompt."); } if (config["Onnx:ModelPath"] is not null) { - services.AddOnnxRuntimeGenAIChatCompletion(serviceId: "onnx", modelId: "phi-3", modelPath: config["Onnx:ModelPath"]!); + services.AddOnnxRuntimeGenAIChatCompletion( + serviceId: "onnx", + modelId: "phi-3", + modelPath: config["Onnx:ModelPath"]!); + Console.WriteLine("• ONNX Added - Use \"onnx\" in the prompt."); } + if (config["AzureAIInference:Endpoint"] is not null) + { + services.AddAzureAIInferenceChatCompletion( + serviceId: "azureai", + endpoint: new Uri(config["AzureAIInference:Endpoint"]!), + apiKey: config["AzureAIInference:ApiKey"]); + + Console.WriteLine("• Azure AI Inference Added - Use \"azureai\" in the prompt."); + } + // Adding a custom filter to capture router selected service id services.AddSingleton(new SelectedServiceFilter()); @@ -70,7 +92,7 @@ private static async Task Main(string[] args) // Find the best service to use based on the user's input KernelArguments arguments = new(new PromptExecutionSettings() { - ServiceId = router.FindService(userMessage, ["lmstudio", "ollama", "openai", "onnx"]) + ServiceId = router.FindService(userMessage, ["lmstudio", "ollama", "openai", "onnx", "azureai"]) }); // Invoke the prompt and print the response diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj new file mode 100644 index 000000000000..acf3f919710f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Connectors.AzureAIInference.UnitTests.csproj @@ -0,0 +1,48 @@ + + + + SemanticKernel.Connectors.AzureAIInference.UnitTests + $(AssemblyName) + net8.0 + true + enable + disable + false + $(NoWarn);CA2007,CA1806,CS1591,CA1869,CA1861,IDE0300,VSTHRD111,SKEXP0001,SKEXP0070 + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + + + + + + + Always + + + + diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Core/ChatClientCoreTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Core/ChatClientCoreTests.cs new file mode 100644 index 000000000000..d844ac784ba9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Core/ChatClientCoreTests.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Azure; +using Azure.AI.Inference; +using Azure.Core; +using Azure.Core.Pipeline; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel.Connectors.AzureAIInference.Core; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.Services; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Core; + +public sealed class ChatClientCoreTests +{ + private readonly Uri _endpoint = new("http://localhost"); + + [Fact] + public void ItCanBeInstantiatedAndPropertiesSetAsExpected() + { + // Arrange + var logger = new Mock>().Object; + var breakingGlassClient = new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("key")); + + // Act + var clientCoreModelConstructor = new ChatClientCore("model1", "apiKey", this._endpoint); + var clientCoreBreakingGlassConstructor = new ChatClientCore("model1", breakingGlassClient, logger: logger); + + // Assert + Assert.Equal("model1", clientCoreModelConstructor.ModelId); + Assert.Equal("model1", clientCoreBreakingGlassConstructor.ModelId); + + Assert.NotNull(clientCoreModelConstructor.Client); + Assert.NotNull(clientCoreBreakingGlassConstructor.Client); + Assert.Equal(breakingGlassClient, clientCoreBreakingGlassConstructor.Client); + Assert.Equal(NullLogger.Instance, clientCoreModelConstructor.Logger); + Assert.Equal(logger, clientCoreBreakingGlassConstructor.Logger); + } + + [Theory] + [InlineData("http://localhost", null)] + [InlineData(null, "http://localhost")] + [InlineData("http://localhost-1", "http://localhost-2")] + public void ItUsesEndpointAsExpected(string? clientBaseAddress, string? providedEndpoint) + { + // Arrange + Uri? endpoint = null; + HttpClient? client = null; + if (providedEndpoint is not null) + { + endpoint = new Uri(providedEndpoint); + } + + if (clientBaseAddress is not null) + { + client = new HttpClient { BaseAddress = new Uri(clientBaseAddress) }; + } + + // Act + var clientCore = new ChatClientCore("model", "apiKey", endpoint: endpoint, httpClient: client); + + // Assert + Assert.Equal(endpoint ?? client?.BaseAddress ?? new Uri("https://api.openai.com/v1"), clientCore.Endpoint); + + Assert.True(clientCore.Attributes.ContainsKey(AIServiceExtensions.EndpointKey)); + Assert.Equal(endpoint?.ToString() ?? 
client?.BaseAddress?.ToString(), clientCore.Attributes[AIServiceExtensions.EndpointKey]); + + client?.Dispose(); + } + + [Fact] + public void ItThrowsIfNoEndpointOptionIsProvided() + { + // Act & Assert + Assert.Throws(() => new ChatClientCore("model", "apiKey", endpoint: null, httpClient: null)); + } + + [Fact] + public async Task ItAddSemanticKernelHeadersOnEachRequestAsync() + { + // Arrange + using HttpMessageHandlerStub handler = new(); + using HttpClient httpClient = new(handler); + httpClient.BaseAddress = this._endpoint; + handler.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK); + + var clientCore = new ChatClientCore(modelId: "model", apiKey: "test", httpClient: httpClient); + + var pipelineMessage = clientCore.Client!.Pipeline.CreateMessage(); + pipelineMessage.Request.Method = RequestMethod.Post; + pipelineMessage.Request.Uri = new RequestUriBuilder() { Host = "localhost", Scheme = "https" }; + pipelineMessage.Request.Content = RequestContent.Create(new BinaryData("test")); + + // Act + await clientCore.Client.Pipeline.SendAsync(pipelineMessage, CancellationToken.None); + + // Assert + Assert.True(handler.RequestHeaders!.Contains(HttpHeaderConstant.Names.SemanticKernelVersion)); + Assert.Equal(HttpHeaderConstant.Values.GetAssemblyVersion(typeof(ChatClientCore)), handler.RequestHeaders.GetValues(HttpHeaderConstant.Names.SemanticKernelVersion).FirstOrDefault()); + + Assert.True(handler.RequestHeaders.Contains("User-Agent")); + Assert.Contains(HttpHeaderConstant.Values.UserAgent, handler.RequestHeaders.GetValues("User-Agent").FirstOrDefault()); + } + + [Fact] + public async Task ItDoesNotAddSemanticKernelHeadersWhenBreakingGlassClientIsProvidedAsync() + { + // Arrange + using HttpMessageHandlerStub handler = new(); + using HttpClient httpClient = new(handler); + handler.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK); + + var clientCore = new ChatClientCore( + modelId: "model", + chatClient: new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("api-key"), + new ChatCompletionsClientOptions() + { + Transport = new HttpClientTransport(httpClient), + RetryPolicy = new RetryPolicy(maxRetries: 0), // Disable Azure SDK retry policy if and only if a custom HttpClient is provided. + Retry = { NetworkTimeout = Timeout.InfiniteTimeSpan } // Disable Azure SDK default timeout + })); + + var pipelineMessage = clientCore.Client!.Pipeline.CreateMessage(); + pipelineMessage.Request.Method = RequestMethod.Post; + pipelineMessage.Request.Uri = new RequestUriBuilder { Scheme = "http", Host = "http://localhost" }; + pipelineMessage.Request.Content = RequestContent.Create(new BinaryData("test")); + + // Act + await clientCore.Client.Pipeline.SendAsync(pipelineMessage, CancellationToken.None); + + // Assert + Assert.False(handler.RequestHeaders!.Contains(HttpHeaderConstant.Names.SemanticKernelVersion)); + Assert.DoesNotContain(HttpHeaderConstant.Values.UserAgent, handler.RequestHeaders.GetValues("User-Agent").FirstOrDefault()); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData("value")] + public void ItAddsAttributesButDoesNothingIfNullOrEmpty(string? 
value) + { + // Arrange + var clientCore = new ChatClientCore("model", "api-key", this._endpoint); + + // Act + clientCore.AddAttribute("key", value); + + // Assert + if (string.IsNullOrEmpty(value)) + { + Assert.False(clientCore.Attributes.ContainsKey("key")); + } + else + { + Assert.True(clientCore.Attributes.ContainsKey("key")); + Assert.Equal(value, clientCore.Attributes["key"]); + } + } + + [Fact] + public void ItAddsModelIdAttributeAsExpected() + { + // Arrange + var expectedModelId = "modelId"; + + // Act + var clientCore = new ChatClientCore(expectedModelId, "api-key", this._endpoint); + var clientCoreBreakingGlass = new ChatClientCore(expectedModelId, new ChatCompletionsClient(this._endpoint, new AzureKeyCredential(" "))); + + // Assert + Assert.True(clientCore.Attributes.ContainsKey(AIServiceExtensions.ModelIdKey)); + Assert.True(clientCoreBreakingGlass.Attributes.ContainsKey(AIServiceExtensions.ModelIdKey)); + Assert.Equal(expectedModelId, clientCore.Attributes[AIServiceExtensions.ModelIdKey]); + Assert.Equal(expectedModelId, clientCoreBreakingGlass.Attributes[AIServiceExtensions.ModelIdKey]); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs new file mode 100644 index 000000000000..8d5b31548b5f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure; +using Azure.AI.Inference; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Extensions; +public sealed class AzureAIInferenceKernelBuilderExtensionsTests +{ + private readonly Uri _endpoint = new("https://endpoint"); + + [Theory] + [InlineData(InitializationType.ApiKey)] + [InlineData(InitializationType.BreakingGlassClientInline)] + [InlineData(InitializationType.BreakingGlassInServiceProvider)] + public void KernelBuilderAddAzureAIInferenceChatCompletionAddsValidService(InitializationType type) + { + // Arrange + var client = new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("key")); + var builder = Kernel.CreateBuilder(); + + builder.Services.AddSingleton(client); + + // Act + builder = type switch + { + InitializationType.ApiKey => builder.AddAzureAIInferenceChatCompletion("model-id", "api-key", this._endpoint), + InitializationType.BreakingGlassClientInline => builder.AddAzureAIInferenceChatCompletion("model-id", client), + InitializationType.BreakingGlassInServiceProvider => builder.AddAzureAIInferenceChatCompletion("model-id", chatClient: null), + _ => builder + }; + + // Assert + var chatCompletionService = builder.Build().GetRequiredService(); + Assert.True(chatCompletionService is AzureAIInferenceChatCompletionService); + } + + public enum InitializationType + { + ApiKey, + BreakingGlassClientInline, + BreakingGlassInServiceProvider, + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs new file 
mode 100644 index 000000000000..02b26f12921b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure; +using Azure.AI.Inference; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Extensions; + +public sealed class AzureAIInferenceServiceCollectionExtensionsTests +{ + private readonly Uri _endpoint = new("https://endpoint"); + + [Theory] + [InlineData(InitializationType.ApiKey)] + [InlineData(InitializationType.ClientInline)] + [InlineData(InitializationType.ClientInServiceProvider)] + public void ItCanAddChatCompletionService(InitializationType type) + { + // Arrange + var client = new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("key")); + var builder = Kernel.CreateBuilder(); + + builder.Services.AddSingleton(client); + + // Act + IServiceCollection collection = type switch + { + InitializationType.ApiKey => builder.Services.AddAzureAIInferenceChatCompletion("modelId", "api-key", this._endpoint), + InitializationType.ClientInline => builder.Services.AddAzureAIInferenceChatCompletion("modelId", client), + InitializationType.ClientInServiceProvider => builder.Services.AddAzureAIInferenceChatCompletion("modelId", chatClient: null), + _ => builder.Services + }; + + // Assert + var chatCompletionService = builder.Build().GetRequiredService(); + Assert.True(chatCompletionService is AzureAIInferenceChatCompletionService); + } + + public enum InitializationType + { + ApiKey, + ClientInline, + ClientInServiceProvider, + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs new file mode 100644 index 000000000000..44bd2c006661 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Azure; +using Azure.AI.Inference; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Services; + +/// +/// Tests for the class. 
+/// +public sealed class AzureAIInferenceChatCompletionServiceTests : IDisposable +{ + private readonly Uri _endpoint = new("https://localhost:1234"); + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly MultipleHttpMessageHandlerStub _multiMessageHandlerStub; + private readonly HttpClient _httpClient; + private readonly HttpClient _httpClientWithBaseAddress; + private readonly AzureAIInferencePromptExecutionSettings _executionSettings; + private readonly Mock _mockLoggerFactory; + private readonly ChatHistory _chatHistoryForTest = [new ChatMessageContent(AuthorRole.User, "test")]; + + public AzureAIInferenceChatCompletionServiceTests() + { + this._messageHandlerStub = new HttpMessageHandlerStub(); + this._multiMessageHandlerStub = new MultipleHttpMessageHandlerStub(); + this._httpClient = new HttpClient(this._messageHandlerStub, false); + this._httpClientWithBaseAddress = new HttpClient(this._messageHandlerStub, false) { BaseAddress = this._endpoint }; + this._mockLoggerFactory = new Mock(); + this._executionSettings = new AzureAIInferencePromptExecutionSettings(); + } + + /// + /// Checks that the constructors work as expected. + /// + [Fact] + public void ConstructorsWorksAsExpected() + { + // Arrange + using var httpClient = new HttpClient() { BaseAddress = this._endpoint }; + var loggerFactoryMock = new Mock(); + ChatCompletionsClient client = new(this._endpoint, new AzureKeyCredential("api-key")); + + // Act & Assert + // Endpoint constructor + new AzureAIInferenceChatCompletionService(endpoint: this._endpoint, apiKey: null); // Only the endpoint + new AzureAIInferenceChatCompletionService(httpClient: httpClient, apiKey: null); // Only the HttpClient with a BaseClass defined + new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null); // ModelId and endpoint + new AzureAIInferenceChatCompletionService(modelId: "model", apiKey: "api-key", endpoint: this._endpoint); // ModelId, apiKey, and endpoint + new AzureAIInferenceChatCompletionService(endpoint: this._endpoint, apiKey: null, loggerFactory: loggerFactoryMock.Object); // Endpoint and loggerFactory + + // Breaking Glass constructor + new AzureAIInferenceChatCompletionService(modelId: null, chatClient: client); // Client without model + new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client); // Client + new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client, loggerFactory: loggerFactoryMock.Object); // Client + } + + [Theory] + [InlineData("http://localhost:1234/chat/completions")] // Uses full path when provided + [InlineData("http://localhost:1234/v2/chat/completions")] // Uses full path when provided + [InlineData("http://localhost:1234")] + [InlineData("http://localhost:8080")] + [InlineData("https://something:8080")] // Accepts TLS Secured endpoints + [InlineData("http://localhost:1234/v2")] + [InlineData("http://localhost:8080/v2")] + public async Task ItUsesCustomEndpointsWhenProvidedDirectlyAsync(string endpoint) + { + // Arrange + var chatCompletion = new AzureAIInferenceChatCompletionService(modelId: "any", apiKey: null, httpClient: this._httpClient, endpoint: new Uri(endpoint)); + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { Content = this.CreateDefaultStringContent() }; + + // Act + await chatCompletion.GetChatMessageContentsAsync(this._chatHistoryForTest, this._executionSettings); + + // Assert + Assert.StartsWith($"{endpoint}/chat/completions", 
this._messageHandlerStub.RequestUri!.ToString()); + } + + [Theory] + [InlineData("http://localhost:1234/chat/completions")] // Uses full path when provided + [InlineData("http://localhost:1234/v2/chat/completions")] // Uses full path when provided + [InlineData("http://localhost:1234")] + [InlineData("http://localhost:8080")] + [InlineData("https://something:8080")] // Accepts TLS Secured endpoints + [InlineData("http://localhost:1234/v2")] + [InlineData("http://localhost:8080/v2")] + public async Task ItPrioritizesCustomEndpointOverHttpClientBaseAddressAsync(string endpoint) + { + // Arrange + this._httpClient.BaseAddress = new Uri("http://should-be-overridden"); + var chatCompletion = new AzureAIInferenceChatCompletionService(modelId: "any", apiKey: null, httpClient: this._httpClient, endpoint: new Uri(endpoint)); + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { Content = this.CreateDefaultStringContent() }; + + // Act + await chatCompletion.GetChatMessageContentsAsync(this._chatHistoryForTest, this._executionSettings); + + // Assert + Assert.StartsWith($"{endpoint}/chat/completions", this._messageHandlerStub.RequestUri!.ToString()); + } + + [Fact] + public async Task ItUsesHttpClientBaseAddressWhenNoEndpointIsProvidedAsync() + { + // Arrange + this._httpClient.BaseAddress = this._endpoint; + var chatCompletion = new AzureAIInferenceChatCompletionService(modelId: "any", httpClient: this._httpClient); + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) + { Content = this.CreateDefaultStringContent() }; + + // Act + await chatCompletion.GetChatMessageContentsAsync(this._chatHistoryForTest, this._executionSettings); + + // Assert + Assert.StartsWith(this._endpoint.ToString(), this._messageHandlerStub.RequestUri?.ToString()); + } + + [Fact] + public void ItThrowsIfNoEndpointOrNoHttpClientBaseAddressIsProvided() + { + // Act & Assert + Assert.Throws(() => new AzureAIInferenceChatCompletionService(endpoint: null, httpClient: this._httpClient)); + } + + [Fact] + public async Task ItGetChatMessageContentsShouldHaveModelIdDefinedAsync() + { + // Arrange + var chatCompletion = new AzureAIInferenceChatCompletionService(apiKey: "NOKEY", httpClient: this._httpClientWithBaseAddress); + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { Content = this.CreateDefaultStringContent() }; + + var chatHistory = new ChatHistory(); + chatHistory.AddMessage(AuthorRole.User, "Hello"); + + // Act + var chatMessage = await chatCompletion.GetChatMessageContentAsync(chatHistory, this._executionSettings); + + // Assert + Assert.NotNull(chatMessage.ModelId); + Assert.Equal("phi3-medium-4k", chatMessage.ModelId); + } + + [Fact] + public async Task GetStreamingChatMessageContentsWorksCorrectlyAsync() + { + // Arrange + var service = new AzureAIInferenceChatCompletionService(httpClient: this._httpClientWithBaseAddress); + await using var stream = File.OpenRead("TestData/chat_completion_streaming_response.txt"); + + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StreamContent(stream) + }; + + // Act & Assert + var enumerator = service.GetStreamingChatMessageContentsAsync([]).GetAsyncEnumerator(); + + await enumerator.MoveNextAsync(); + Assert.Equal(AuthorRole.Assistant, enumerator.Current.Role); + + await enumerator.MoveNextAsync(); + Assert.Equal("Test content", enumerator.Current.Content); + Assert.Equal("stop", 
enumerator.Current.Metadata?["FinishReason"]); + } + + [Fact] + public async Task GetChatMessageContentsWithChatMessageContentItemCollectionCorrectlyAsync() + { + // Arrange + const string Prompt = "This is test prompt"; + const string AssistantMessage = "This is assistant message"; + const string CollectionItemPrompt = "This is collection item prompt"; + var chatCompletion = new AzureAIInferenceChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClientWithBaseAddress); + + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { Content = this.CreateDefaultStringContent() }; + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage(Prompt); + chatHistory.AddAssistantMessage(AssistantMessage); + chatHistory.AddUserMessage( + [ + new TextContent(CollectionItemPrompt), + new ImageContent(new Uri("https://image")) + ]); + + // Act + await chatCompletion.GetChatMessageContentsAsync(chatHistory); + + // Assert + var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(actualRequestContent); + var optionsJson = JsonSerializer.Deserialize(actualRequestContent); + + var messages = optionsJson.GetProperty("messages"); + + Assert.Equal(3, messages.GetArrayLength()); + + Assert.Equal(Prompt, messages[0].GetProperty("content").GetString()); + Assert.Equal("user", messages[0].GetProperty("role").GetString()); + + Assert.Equal(AssistantMessage, messages[1].GetProperty("content").GetString()); + Assert.Equal("assistant", messages[1].GetProperty("role").GetString()); + + var contentItems = messages[2].GetProperty("content"); + Assert.Equal(2, contentItems.GetArrayLength()); + Assert.Equal(CollectionItemPrompt, contentItems[0].GetProperty("text").GetString()); + Assert.Equal("text", contentItems[0].GetProperty("type").GetString()); + Assert.Equal("https://image/", contentItems[1].GetProperty("image_url").GetProperty("url").GetString()); + Assert.Equal("image_url", contentItems[1].GetProperty("type").GetString()); + } + + [Theory] + [InlineData("string", "json_object")] + [InlineData("string", "text")] + [InlineData("string", "random")] + [InlineData("JsonElement.String", "\"json_object\"")] + [InlineData("JsonElement.String", "\"text\"")] + [InlineData("JsonElement.String", "\"random\"")] + [InlineData("ChatResponseFormat", "json_object")] + [InlineData("ChatResponseFormat", "text")] + public async Task GetChatMessageInResponseFormatsAsync(string formatType, string formatValue) + { + // Arrange + object? format = null; + switch (formatType) + { + case "string": + format = formatValue; + break; + case "JsonElement.String": + format = JsonSerializer.Deserialize(formatValue); + break; + case "ChatResponseFormat": + format = formatValue == "text" ? 
new ChatCompletionsResponseFormatText() : new ChatCompletionsResponseFormatJSON(); + break; + } + + var sut = new AzureAIInferenceChatCompletionService(httpClient: this._httpClientWithBaseAddress); + AzureAIInferencePromptExecutionSettings executionSettings = new() { ResponseFormat = format }; + + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(File.ReadAllText("TestData/chat_completion_response.json")) + }; + + // Act + var result = await sut.GetChatMessageContentAsync(this._chatHistoryForTest, executionSettings); + + // Assert + Assert.NotNull(result); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._httpClientWithBaseAddress.Dispose(); + this._messageHandlerStub.Dispose(); + this._multiMessageHandlerStub.Dispose(); + } + + private StringContent CreateDefaultStringContent() + { + return new StringContent(File.ReadAllText("TestData/chat_completion_response.json")); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Settings/AzureAIInferencePromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Settings/AzureAIInferencePromptExecutionSettingsTests.cs new file mode 100644 index 000000000000..c61a261e7d30 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Settings/AzureAIInferencePromptExecutionSettingsTests.cs @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; +using Xunit; + +namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Settings; +public sealed class AzureAIInferencePromptExecutionSettingsTests +{ + [Fact] + public void ItCreatesAzureAIInferenceExecutionSettingsWithCorrectDefaults() + { + // Arrange + // Act + AzureAIInferencePromptExecutionSettings executionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(null); + + // Assert + Assert.NotNull(executionSettings); + Assert.Null(executionSettings.Temperature); + Assert.Null(executionSettings.FrequencyPenalty); + Assert.Null(executionSettings.PresencePenalty); + Assert.Null(executionSettings.NucleusSamplingFactor); + Assert.Null(executionSettings.ResponseFormat); + Assert.Null(executionSettings.Seed); + Assert.Null(executionSettings.MaxTokens); + Assert.Empty(executionSettings.ExtensionData!); + Assert.Empty(executionSettings.Tools); + Assert.Empty(executionSettings.StopSequences!); + } + + [Fact] + public void ItUsesExistingAzureAIInferenceExecutionSettings() + { + // Arrange + AzureAIInferencePromptExecutionSettings actualSettings = new() + { + Temperature = 0.7f, + NucleusSamplingFactor = 0.7f, + FrequencyPenalty = 0.7f, + PresencePenalty = 0.7f, + StopSequences = ["foo", "bar"], + MaxTokens = 128 + }; + + // Act + AzureAIInferencePromptExecutionSettings executionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(actualSettings); + + // Assert + Assert.NotNull(executionSettings); + Assert.Equal(actualSettings, executionSettings); + Assert.Equal(128, executionSettings.MaxTokens); + } + + [Fact] + public void ItCanUseAzureAIInferenceExecutionSettings() + { + // Arrange + PromptExecutionSettings actualSettings = new() + { + ExtensionData = new Dictionary() { + { "max_tokens", 1000 }, + { "temperature", 0 } + } + }; + + // Act + AzureAIInferencePromptExecutionSettings executionSettings = 
AzureAIInferencePromptExecutionSettings.FromExecutionSettings(actualSettings); + + // Assert + Assert.NotNull(executionSettings); + Assert.Equal(1000, executionSettings.MaxTokens); + Assert.Equal(0, executionSettings.Temperature); + } + + [Fact] + public void ItCreatesAzureAIInferenceExecutionSettingsFromExtraPropertiesSnakeCase() + { + // Arrange + PromptExecutionSettings actualSettings = new() + { + ExtensionData = new Dictionary() + { + { "temperature", 0.7 }, + { "top_p", 0.7 }, + { "frequency_penalty", 0.7 }, + { "presence_penalty", 0.7 }, + { "stop", new [] { "foo", "bar" } }, + { "max_tokens", 128 }, + { "seed", 123456 }, + } + }; + + // Act + AzureAIInferencePromptExecutionSettings executionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(actualSettings); + + // Assert + AssertExecutionSettings(executionSettings); + } + + [Fact] + public void ItCreatesAzureAIInferenceExecutionSettingsFromExtraPropertiesAsStrings() + { + // Arrange + PromptExecutionSettings actualSettings = new() + { + ExtensionData = new Dictionary() + { + { "temperature", 0.7 }, + { "top_p", "0.7" }, + { "frequency_penalty", "0.7" }, + { "presence_penalty", "0.7" }, + { "stop", new [] { "foo", "bar" } }, + { "max_tokens", "128" }, + { "response_format", "json" }, + { "seed", 123456 }, + } + }; + + // Act + AzureAIInferencePromptExecutionSettings executionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(actualSettings); + + // Assert + AssertExecutionSettings(executionSettings); + } + + [Fact] + public void ItCreatesAzureAIInferenceExecutionSettingsFromJsonSnakeCase() + { + // Arrange + var json = """ + { + "temperature": 0.7, + "top_p": 0.7, + "frequency_penalty": 0.7, + "presence_penalty": 0.7, + "stop": [ "foo", "bar" ], + "max_tokens": 128, + "response_format": "text", + "seed": 123456 + } + """; + var actualSettings = JsonSerializer.Deserialize(json); + + // Act + AzureAIInferencePromptExecutionSettings executionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(actualSettings); + + // Assert + AssertExecutionSettings(executionSettings); + } + + [Fact] + public void PromptExecutionSettingsCloneWorksAsExpected() + { + // Arrange + string configPayload = """ + { + "max_tokens": 60, + "temperature": 0.5, + "top_p": 0.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0 + } + """; + var executionSettings = JsonSerializer.Deserialize(configPayload); + + // Act + var clone = executionSettings!.Clone(); + + // Assert + Assert.NotNull(clone); + Assert.Equal(executionSettings.ModelId, clone.ModelId); + Assert.Equivalent(executionSettings.ExtensionData, clone.ExtensionData); + } + + [Fact] + public void PromptExecutionSettingsFreezeWorksAsExpected() + { + // Arrange + string configPayload = """ + { + "max_tokens": 60, + "temperature": 0.5, + "top_p": 0.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "response_format": "json", + "stop": [ "DONE" ] + } + """; + var executionSettings = JsonSerializer.Deserialize(configPayload)!; + executionSettings.ExtensionData = new Dictionary() { { "new", 5 } }; + + // Act + executionSettings!.Freeze(); + + // Assert + Assert.True(executionSettings.IsFrozen); + Assert.Throws(() => executionSettings.ModelId = "new-model"); + Assert.Throws(() => executionSettings.Temperature = 1); + Assert.Throws(() => executionSettings.FrequencyPenalty = 1); + Assert.Throws(() => executionSettings.PresencePenalty = 1); + Assert.Throws(() => executionSettings.NucleusSamplingFactor = 1); + Assert.Throws(() => 
executionSettings.MaxTokens = 100); + Assert.Throws(() => executionSettings.ResponseFormat = "text"); + Assert.Throws(() => executionSettings.StopSequences?.Add("STOP")); + Assert.Throws(() => executionSettings.ExtensionData["new"] = 6); + + executionSettings!.Freeze(); // idempotent + Assert.True(executionSettings.IsFrozen); + } + + [Fact] + public void FromExecutionSettingsWithDataDoesNotIncludeEmptyStopSequences() + { + // Arrange + PromptExecutionSettings settings = new AzureAIInferencePromptExecutionSettings { StopSequences = [] }; + + // Act + var executionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(settings); + + // Assert + Assert.NotNull(executionSettings.StopSequences); + Assert.Empty(executionSettings.StopSequences); + } + + private static void AssertExecutionSettings(AzureAIInferencePromptExecutionSettings executionSettings) + { + Assert.NotNull(executionSettings); + Assert.Equal(0.7f, executionSettings.Temperature); + Assert.Equal(0.7f, executionSettings.NucleusSamplingFactor); + Assert.Equal(0.7f, executionSettings.FrequencyPenalty); + Assert.Equal(0.7f, executionSettings.PresencePenalty); + Assert.Equal(["foo", "bar"], executionSettings.StopSequences); + Assert.Equal(128, executionSettings.MaxTokens); + Assert.Equal(123456, executionSettings.Seed); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_response.json b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_response.json new file mode 100644 index 000000000000..c4b1198108fc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_response.json @@ -0,0 +1,22 @@ +{ + "id": "chat-00078bf2c54346c6bfa31e561462c381", + "object": "chat.completion", + "created": 1723641172, + "model": "phi3-medium-4k", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Test response", + "tool_calls": [] + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 17, + "total_tokens": 148, + "completion_tokens": 131 + } +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_streaming_response.txt b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_streaming_response.txt new file mode 100644 index 000000000000..d3ef93e3b439 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/TestData/chat_completion_streaming_response.txt @@ -0,0 +1,7 @@ +data: {"id":"chat-6035afe96714485eb0998fe041bfdbdb","object":"chat.completion.chunk","created":1723641572,"model":"phi3-medium-4k","choices":[{"index":0,"delta":{"role":"assistant"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":17,"total_tokens":17,"completion_tokens":0}} + +data: {"id":"chat-6035afe96714485eb0998fe041bfdbdb","object":"chat.completion.chunk","created":1723641572,"model":"phi3-medium-4k","choices":[{"index":0,"delta":{"content":"Test content"},"logprobs":null,"finish_reason":"stop","stop_reason":32007}],"usage":{"prompt_tokens":17,"total_tokens":106,"completion_tokens":89}} + +data: {"id":"chat-6035afe96714485eb0998fe041bfdbdb","object":"chat.completion.chunk","created":1723641572,"model":"phi3-medium-4k","choices":[],"usage":{"prompt_tokens":17,"total_tokens":106,"completion_tokens":89}} + +data: [DONE] \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/AssemblyInfo.cs 
b/dotnet/src/Connectors/Connectors.AzureAIInference/AssemblyInfo.cs new file mode 100644 index 000000000000..fe66371dbc58 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. +[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj b/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj new file mode 100644 index 000000000000..2f87b005fda1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Connectors.AzureAIInference.csproj @@ -0,0 +1,34 @@ + + + + + Microsoft.SemanticKernel.Connectors.AzureAIInference + $(AssemblyName) + net8.0;netstandard2.0 + $(NoWarn);NU5104;SKEXP0001,SKEXP0070 + false + beta + + + + + + + + + Semantic Kernel - Azure AI Inference connectors + Semantic Kernel Model as a Service connectors for Azure AI Studio. Contains clients for chat completion, embeddings and text to image generation. + + + + + + + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Core/AddHeaderRequestPolicy.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/AddHeaderRequestPolicy.cs new file mode 100644 index 000000000000..f263e8dc1a27 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/AddHeaderRequestPolicy.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Azure.Core; +using Azure.Core.Pipeline; + +namespace Microsoft.SemanticKernel.Connectors.AzureAIInference.Core; + +/// +/// Helper class to inject headers into Azure SDK HTTP pipeline +/// +internal sealed class AddHeaderRequestPolicy(string headerName, string headerValue) : HttpPipelineSynchronousPolicy +{ + private readonly string _headerName = headerName; + private readonly string _headerValue = headerValue; + + public override void OnSendingRequest(HttpMessage message) + { + message.Request.Headers.Add(this._headerName, this._headerValue); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs new file mode 100644 index 000000000000..047bda6cabc8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/ChatClientCore.cs @@ -0,0 +1,649 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.Metrics; +using System.Linq; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure; +using Azure.AI.Inference; +using Azure.Core; +using Azure.Core.Pipeline; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Diagnostics; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.Services; + +#pragma warning disable CA2208 // Instantiate argument exceptions correctly + +namespace Microsoft.SemanticKernel.Connectors.AzureAIInference.Core; + +/// +/// Base class for AI clients that provides common functionality for interacting with Azure AI Inference services. +/// +internal sealed class ChatClientCore +{ + /// + /// Non-default endpoint for Azure AI Inference API. + /// + internal Uri? 
Endpoint { get; init; } + + /// + /// Non-default endpoint for Azure AI Inference API. + /// + internal string? ModelId { get; init; } + + /// + /// Logger instance + /// + internal ILogger Logger { get; init; } + + /// + /// Azure AI Inference Client + /// + internal ChatCompletionsClient Client { get; set; } + + /// + /// Storage for AI service attributes. + /// + internal Dictionary Attributes { get; } = []; + + /// + /// Initializes a new instance of the class. + /// + /// Optional target Model Id for endpoints that support multiple models + /// Azure AI Inference API Key. + /// Azure AI Inference compatible API endpoint. + /// Custom for HTTP requests. + /// The to use for logging. If null, no logging will be performed. + internal ChatClientCore( + string? modelId = null, + string? apiKey = null, + Uri? endpoint = null, + HttpClient? httpClient = null, + ILogger? logger = null) + { + this.Logger = logger ?? NullLogger.Instance; + // Accepts the endpoint if provided, otherwise uses the default Azure AI Inference endpoint. + this.Endpoint = endpoint ?? httpClient?.BaseAddress; + Verify.NotNull(this.Endpoint, "endpoint or base-address"); + this.AddAttribute(AIServiceExtensions.EndpointKey, this.Endpoint.ToString()); + + if (string.IsNullOrEmpty(apiKey)) + { + // Api Key is not required, when not provided will be set to single space to avoid empty exceptions from Azure SDK AzureKeyCredential type. + // This is a common scenario when using the Azure AI Inference service thru a Gateway that may inject the API Key. + apiKey = SingleSpace; + } + + if (!string.IsNullOrEmpty(modelId)) + { + this.ModelId = modelId; + this.AddAttribute(AIServiceExtensions.ModelIdKey, modelId); + } + + this.Client = new ChatCompletionsClient(this.Endpoint, new AzureKeyCredential(apiKey!), GetClientOptions(httpClient)); + } + + /// + /// Initializes a new instance of the class. + /// + /// Optional target Model Id for endpoints that support multiple models + /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// Azure AI Inference compatible API endpoint. + /// Custom for HTTP requests. + /// The to use for logging. If null, no logging will be performed. + internal ChatClientCore( + string? modelId = null, + TokenCredential? credential = null, + Uri? endpoint = null, + HttpClient? httpClient = null, + ILogger? logger = null) + { + Verify.NotNull(endpoint); + Verify.NotNull(credential); + this.Logger = logger ?? NullLogger.Instance; + + this.Endpoint = endpoint ?? httpClient?.BaseAddress; + Verify.NotNull(this.Endpoint, "endpoint or base-address"); + this.AddAttribute(AIServiceExtensions.EndpointKey, this.Endpoint.ToString()); + + if (!string.IsNullOrEmpty(modelId)) + { + this.ModelId = modelId; + this.AddAttribute(AIServiceExtensions.ModelIdKey, modelId); + } + + this.Client = new ChatCompletionsClient(this.Endpoint, credential, GetClientOptions(httpClient)); + } + + /// + /// Initializes a new instance of the class using the specified Azure AI Inference Client. + /// Note: instances created this way might not have the default diagnostics settings, + /// it's up to the caller to configure the client. + /// + /// Target Model Id for endpoints supporting more than one + /// Custom . + /// The to use for logging. If null, no logging will be performed. + internal ChatClientCore( + string? modelId, + ChatCompletionsClient chatClient, + ILogger? 
logger = null) + { + Verify.NotNull(chatClient); + if (!string.IsNullOrEmpty(modelId)) + { + this.ModelId = modelId; + this.AddAttribute(AIServiceExtensions.ModelIdKey, modelId); + } + + this.Logger = logger ?? NullLogger.Instance; + this.Client = chatClient; + } + + /// + /// Allows adding attributes to the client. + /// + /// Attribute key. + /// Attribute value. + internal void AddAttribute(string key, string? value) + { + if (!string.IsNullOrEmpty(value)) + { + this.Attributes.Add(key, value); + } + } + + /// + /// Get chat multiple chat content choices for the prompt and settings. + /// + /// + /// This should be used when the settings request for more than one choice. + /// + /// The chat history context. + /// The AI execution settings (optional). + /// The containing services, plugins, and other state for use throughout the operation. + /// The to monitor for cancellation requests. The default is . + /// List of different chat results generated by the remote model + internal async Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(chatHistory); + + // Convert the incoming execution settings to specialized settings. + AzureAIInferencePromptExecutionSettings chatExecutionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(executionSettings); + + ValidateMaxTokens(chatExecutionSettings.MaxTokens); + + // Create the SDK ChatCompletionOptions instance from all available information. + ChatCompletionsOptions chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chatHistory, kernel, this.ModelId); + + // Make the request. + ChatCompletions? responseData = null; + var extraParameters = chatExecutionSettings.ExtraParameters; + + List responseContent; + using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.ModelId ?? string.Empty, ModelProvider, chatHistory, chatExecutionSettings)) + { + try + { + responseData = (await RunRequestAsync(() => this.Client!.CompleteAsync(chatOptions, chatExecutionSettings.ExtraParameters ?? string.Empty, cancellationToken)).ConfigureAwait(false)).Value; + + this.LogUsage(responseData.Usage); + if (responseData.Choices.Count == 0) + { + throw new KernelException("Chat completions not found"); + } + } + catch (Exception ex) when (activity is not null) + { + activity.SetError(ex); + if (responseData != null) + { + // Capture available metadata even if the operation failed. + activity + .SetResponseId(responseData.Id) + .SetPromptTokenUsage(responseData.Usage.PromptTokens) + .SetCompletionTokenUsage(responseData.Usage.CompletionTokens); + } + throw; + } + + responseContent = responseData.Choices.Select(chatChoice => this.GetChatMessage(chatChoice, responseData)).ToList(); + activity?.SetCompletionResponse(responseContent, responseData.Usage.PromptTokens, responseData.Usage.CompletionTokens); + } + + return responseContent; + } + + /// + /// Get streaming chat contents for the chat history provided using the specified settings. + /// + /// Throws if the specified type is not the same or fail to cast + /// The chat history to complete. + /// The AI execution settings (optional). + /// The containing services, plugins, and other state for use throughout the operation. + /// The to monitor for cancellation requests. The default is . 
+ /// Streaming list of different completion streaming string updates generated by the remote model + internal async IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(chatHistory); + + AzureAIInferencePromptExecutionSettings chatExecutionSettings = AzureAIInferencePromptExecutionSettings.FromExecutionSettings(executionSettings); + + ValidateMaxTokens(chatExecutionSettings.MaxTokens); + + var chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chatHistory, kernel, this.ModelId); + StringBuilder? contentBuilder = null; + + // Reset state + contentBuilder?.Clear(); + + // Stream the response. + IReadOnlyDictionary? metadata = null; + string? streamedName = null; + ChatRole? streamedRole = default; + CompletionsFinishReason finishReason = default; + + using var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.ModelId ?? string.Empty, ModelProvider, chatHistory, chatExecutionSettings); + StreamingResponse response; + try + { + response = await RunRequestAsync(() => this.Client.CompleteStreamingAsync(chatOptions, cancellationToken)).ConfigureAwait(false); + } + catch (Exception ex) when (activity is not null) + { + activity.SetError(ex); + throw; + } + + var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); + List? streamedContents = activity is not null ? [] : null; + try + { + while (true) + { + try + { + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + } + catch (Exception ex) when (activity is not null) + { + activity.SetError(ex); + throw; + } + + StreamingChatCompletionsUpdate update = responseEnumerator.Current; + metadata = GetResponseMetadata(update); + streamedRole ??= update.Role; + streamedName ??= update.AuthorName; + finishReason = update.FinishReason ?? default; + + AuthorRole? role = null; + if (streamedRole.HasValue) + { + role = new AuthorRole(streamedRole.Value.ToString()); + } + + StreamingChatMessageContent streamingChatMessageContent = + new(role: update.Role.HasValue ? new AuthorRole(update.Role.ToString()!) : null, content: update.ContentUpdate, innerContent: update, modelId: update.Model, metadata: metadata) + { + AuthorName = streamedName, + Role = role, + Metadata = metadata, + }; + + streamedContents?.Add(streamingChatMessageContent); + yield return streamingChatMessageContent; + } + } + finally + { + activity?.EndStreaming(streamedContents, null); + await responseEnumerator.DisposeAsync(); + } + } + + #region Private + + private const string ModelProvider = "azure-ai-inference"; + /// + /// Instance of for metrics. + /// + private static readonly Meter s_meter = new("Microsoft.SemanticKernel.Connectors.AzureAIInference"); + + /// + /// Instance of to keep track of the number of prompt tokens used. + /// + private static readonly Counter s_promptTokensCounter = + s_meter.CreateCounter( + name: "semantic_kernel.connectors.azure-ai-inference.tokens.prompt", + unit: "{token}", + description: "Number of prompt tokens used"); + + /// + /// Instance of to keep track of the number of completion tokens used. 
+ /// + private static readonly Counter s_completionTokensCounter = + s_meter.CreateCounter( + name: "semantic_kernel.connectors.azure-ai-inference.tokens.completion", + unit: "{token}", + description: "Number of completion tokens used"); + + /// + /// Instance of to keep track of the total number of tokens used. + /// + private static readonly Counter s_totalTokensCounter = + s_meter.CreateCounter( + name: "semantic_kernel.connectors.azure-ai-inference.tokens.total", + unit: "{token}", + description: "Number of tokens used"); + + /// + /// Single space constant. + /// + private const string SingleSpace = " "; + + /// Gets options to use for an Azure AI InferenceClient + /// Custom for HTTP requests. + /// Optional API version. + /// An instance of . + private static ChatCompletionsClientOptions GetClientOptions(HttpClient? httpClient, ChatCompletionsClientOptions.ServiceVersion? serviceVersion = null) + { + ChatCompletionsClientOptions options = serviceVersion is not null ? + new(serviceVersion.Value) : + new(); + + options.Diagnostics.ApplicationId = HttpHeaderConstant.Values.UserAgent; + + options.AddPolicy(new AddHeaderRequestPolicy(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(ChatClientCore))), Azure.Core.HttpPipelinePosition.PerCall); + + if (httpClient is not null) + { + options.Transport = new HttpClientTransport(httpClient); + options.RetryPolicy = new RetryPolicy(maxRetries: 0); // Disable retry policy if and only if a custom HttpClient is provided. + options.Retry.NetworkTimeout = Timeout.InfiniteTimeSpan; // Disable default timeout + } + + return options; + } + + /// + /// Invokes the specified request and handles exceptions. + /// + /// Type of the response. + /// Request to invoke. + /// Returns the response. + private static async Task RunRequestAsync(Func> request) + { + try + { + return await request.Invoke().ConfigureAwait(false); + } + catch (RequestFailedException e) + { + throw e.ToHttpOperationException(); + } + } + + /// + /// Checks if the maximum tokens value is valid. + /// + /// Maximum tokens value. + /// Throws if the maximum tokens value is invalid. + private static void ValidateMaxTokens(int? maxTokens) + { + if (maxTokens.HasValue && maxTokens < 1) + { + throw new ArgumentException($"MaxTokens {maxTokens} is not valid, the value must be greater than zero"); + } + } + + /// + /// Creates a new instance of based on the provided settings. + /// + /// The execution settings. + /// The chat history. + /// Kernel instance. + /// Model ID. + /// Create a new instance of . + private ChatCompletionsOptions CreateChatCompletionsOptions( + AzureAIInferencePromptExecutionSettings executionSettings, + ChatHistory chatHistory, + Kernel? kernel, + string? 
modelId) + { + if (this.Logger.IsEnabled(LogLevel.Trace)) + { + this.Logger.LogTrace("ChatHistory: {ChatHistory}, Settings: {Settings}", + JsonSerializer.Serialize(chatHistory), + JsonSerializer.Serialize(executionSettings)); + } + + var options = new ChatCompletionsOptions + { + MaxTokens = executionSettings.MaxTokens, + Temperature = executionSettings.Temperature, + NucleusSamplingFactor = executionSettings.NucleusSamplingFactor, + FrequencyPenalty = executionSettings.FrequencyPenalty, + PresencePenalty = executionSettings.PresencePenalty, + Model = modelId, + Seed = executionSettings.Seed, + }; + + switch (executionSettings.ResponseFormat) + { + case ChatCompletionsResponseFormat formatObject: + // If the response format is an Azure SDK ChatCompletionsResponseFormat, just pass it along. + options.ResponseFormat = formatObject; + break; + + case string formatString: + // If the response format is a string, map the ones we know about, and ignore the rest. + switch (formatString) + { + case "json_object": + options.ResponseFormat = new ChatCompletionsResponseFormatJSON(); + break; + + case "text": + options.ResponseFormat = new ChatCompletionsResponseFormatText(); + break; + } + break; + + case JsonElement formatElement: + // This is a workaround for a type mismatch when deserializing a JSON into an object? type property. + // Handling only string formatElement. + if (formatElement.ValueKind == JsonValueKind.String) + { + string formatString = formatElement.GetString() ?? ""; + switch (formatString) + { + case "json_object": + options.ResponseFormat = new ChatCompletionsResponseFormatJSON(); + break; + + case "text": + options.ResponseFormat = new ChatCompletionsResponseFormatText(); + break; + } + } + break; + } + + if (executionSettings.StopSequences is { Count: > 0 }) + { + foreach (var s in executionSettings.StopSequences) + { + options.StopSequences.Add(s); + } + } + + foreach (var message in chatHistory) + { + options.Messages.AddRange(GetRequestMessages(message)); + } + + return options; + } + + /// + /// Create request messages based on the chat message content. + /// + /// Chat message content. + /// A list of . + /// When the message role is not supported. + private static List GetRequestMessages(ChatMessageContent message) + { + if (message.Role == AuthorRole.System) + { + return [new ChatRequestSystemMessage(message.Content)]; + } + + if (message.Role == AuthorRole.User) + { + if (message.Items is { Count: 1 } && message.Items.FirstOrDefault() is TextContent textContent) + { + // Name removed temporarily as the Azure AI Inference service does not support it ATM. + // Issue: https://github.com/Azure/azure-sdk-for-net/issues/45415 + return [new ChatRequestUserMessage(textContent.Text) /*{ Name = message.AuthorName }*/ ]; + } + + return [new ChatRequestUserMessage(message.Items.Select(static (KernelContent item) => (ChatMessageContentItem)(item switch + { + TextContent textContent => new ChatMessageTextContentItem(textContent.Text), + ImageContent imageContent => GetImageContentItem(imageContent), + _ => throw new NotSupportedException($"Unsupported chat message content type '{item.GetType()}'.") + }))) + + // Name removed temporarily as the Azure AI Inference service does not support it ATM. + // Issue: https://github.com/Azure/azure-sdk-for-net/issues/45415 + /*{ Name = message.AuthorName }*/]; + } + + if (message.Role == AuthorRole.Assistant) + { + // Name removed temporarily as the Azure AI Inference service does not support it ATM. 
+ // Issue: https://github.com/Azure/azure-sdk-for-net/issues/45415 + return [new ChatRequestAssistantMessage() { Content = message.Content /* Name = message.AuthorName */ }]; + } + + throw new NotSupportedException($"Role {message.Role} is not supported."); + } + + /// + /// Create a new instance of based on the provided + /// + /// Target . + /// new instance of + /// When the does not have Data or Uri. + private static ChatMessageImageContentItem GetImageContentItem(ImageContent imageContent) + { + if (imageContent.Data is { IsEmpty: false } data) + { + return new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MimeType); + } + + if (imageContent.Uri is not null) + { + return new ChatMessageImageContentItem(imageContent.Uri); + } + + throw new ArgumentException($"{nameof(ImageContent)} must have either Data or a Uri."); + } + + /// + /// Captures usage details, including token information. + /// + /// Instance of with usage details. + private void LogUsage(CompletionsUsage usage) + { + if (usage is null) + { + this.Logger.LogDebug("Token usage information unavailable."); + return; + } + + if (this.Logger.IsEnabled(LogLevel.Information)) + { + this.Logger.LogInformation( + "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.", + usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens); + } + + s_promptTokensCounter.Add(usage.PromptTokens); + s_completionTokensCounter.Add(usage.CompletionTokens); + s_totalTokensCounter.Add(usage.TotalTokens); + } + + /// + /// Create a new based on the provided and . + /// + /// The object representing the selected choice. + /// The object containing the response data. + /// A new object. + private ChatMessageContent GetChatMessage(ChatChoice chatChoice, ChatCompletions responseData) + { + var message = new ChatMessageContent( + new AuthorRole(chatChoice.Message.Role.ToString()), + chatChoice.Message.Content, + responseData.Model, + innerContent: responseData, + metadata: GetChatChoiceMetadata(responseData, chatChoice) + ); + return message; + } + + /// + /// Create the metadata dictionary based on the provided and . + /// + /// The object containing the response data. + /// The object representing the selected choice. + /// A new dictionary with metadata. + private static Dictionary GetChatChoiceMetadata(ChatCompletions completions, ChatChoice chatChoice) + { + return new Dictionary(5) + { + { nameof(completions.Id), completions.Id }, + { nameof(completions.Created), completions.Created }, + { nameof(completions.Usage), completions.Usage }, + + // Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it. + { nameof(chatChoice.FinishReason), chatChoice.FinishReason?.ToString() }, + { nameof(chatChoice.Index), chatChoice.Index }, + }; + } + + /// + /// Create the metadata dictionary based on the provided . + /// + /// The object containing the response data. + /// A new dictionary with metadata. + private static Dictionary GetResponseMetadata(StreamingChatCompletionsUpdate completions) + { + return new Dictionary(3) + { + { nameof(completions.Id), completions.Id }, + { nameof(completions.Created), completions.Created }, + + // Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it. 
+ { nameof(completions.FinishReason), completions.FinishReason?.ToString() }, + }; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Core/RequestFailedExceptionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/RequestFailedExceptionExtensions.cs new file mode 100644 index 000000000000..37d5890da116 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Core/RequestFailedExceptionExtensions.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Net; +using Azure; + +namespace Microsoft.SemanticKernel.Connectors.AzureAIInference; + +/// +/// Provides extension methods for the class. +/// +internal static class RequestFailedExceptionExtensions +{ + /// + /// Converts a to an . + /// + /// The original . + /// An instance. + public static HttpOperationException ToHttpOperationException(this RequestFailedException exception) + { + const int NoResponseReceived = 0; + + string? responseContent = null; + + try + { + responseContent = exception.GetRawResponse()?.Content?.ToString(); + } +#pragma warning disable CA1031 // Do not catch general exception types + catch { } // We want to suppress any exceptions that occur while reading the content, ensuring that an HttpOperationException is thrown instead. +#pragma warning restore CA1031 + + return new HttpOperationException( + exception.Status == NoResponseReceived ? null : (HttpStatusCode?)exception.Status, + responseContent, + exception.Message, + exception); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs new file mode 100644 index 000000000000..c1760d4ac316 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using Azure.AI.Inference; +using Azure.Core; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; + +namespace Microsoft.SemanticKernel; + +/// +/// Provides extension methods for to configure Azure AI Inference connectors. +/// +public static class AzureAIInferenceKernelBuilderExtensions +{ + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Target Model Id for endpoints supporting more than one model + /// API Key + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// A local identifier for the given AI service + /// The same instance as . + public static IKernelBuilder AddAzureAIInferenceChatCompletion( + this IKernelBuilder builder, + string? modelId = null, + string? apiKey = null, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureAIInferenceChatCompletion(modelId, apiKey, endpoint, httpClient, serviceId); + + return builder; + } + + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Target Model Id for endpoints supporting more than one model + /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// A local identifier for the given AI service + /// The same instance as . + public static IKernelBuilder AddAzureAIInferenceChatCompletion( + this IKernelBuilder builder, + string? 
modelId, + TokenCredential credential, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureAIInferenceChatCompletion(modelId, credential, endpoint, httpClient, serviceId); + + return builder; + } + + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Azure AI Inference model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IKernelBuilder AddAzureAIInferenceChatCompletion( + this IKernelBuilder builder, + string modelId, + ChatCompletionsClient? chatClient = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureAIInferenceChatCompletion(modelId, chatClient, serviceId); + + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs new file mode 100644 index 000000000000..b508b38537d3 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using Azure.AI.Inference; +using Azure.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; +using Microsoft.SemanticKernel.Http; + +namespace Microsoft.SemanticKernel; + +/// +/// Provides extension methods for to configure Azure AI Inference connectors. +/// +public static class AzureAIInferenceServiceCollectionExtensions +{ + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Target Model Id for endpoints supporting more than one model + /// API Key + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// A local identifier for the given AI service + /// The same instance as . + public static IServiceCollection AddAzureAIInferenceChatCompletion( + this IServiceCollection services, + string? modelId = null, + string? apiKey = null, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null) + { + Verify.NotNull(services); + + AzureAIInferenceChatCompletionService Factory(IServiceProvider serviceProvider, object? _) => + new(modelId, + apiKey, + endpoint, + HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + serviceProvider.GetService()); + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Target Model Id for endpoints supporting more than one model + /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// A local identifier for the given AI service + /// The same instance as . + public static IServiceCollection AddAzureAIInferenceChatCompletion( + this IServiceCollection services, + string? modelId, + TokenCredential credential, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null) + { + Verify.NotNull(services); + + AzureAIInferenceChatCompletionService Factory(IServiceProvider serviceProvider, object? 
_) => + new(modelId, + credential, + endpoint, + HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + serviceProvider.GetService()); + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Azure AI Inference model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IServiceCollection AddAzureAIInferenceChatCompletion(this IServiceCollection services, + string modelId, + ChatCompletionsClient? chatClient = null, + string? serviceId = null) + { + Verify.NotNull(services); + + AzureAIInferenceChatCompletionService Factory(IServiceProvider serviceProvider, object? _) => + new(modelId, chatClient ?? serviceProvider.GetRequiredService(), serviceProvider.GetService()); + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs new file mode 100644 index 000000000000..0b55ac3cd696 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.Inference; +using Azure.Core; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference.Core; + +namespace Microsoft.SemanticKernel.Connectors.AzureAIInference; + +/// +/// Chat completion service for Azure AI Inference. +/// +public sealed class AzureAIInferenceChatCompletionService : IChatCompletionService +{ + private readonly ChatClientCore _core; + + /// + /// Initializes a new instance of the class. + /// + /// Target Model Id for endpoints supporting more than one model + /// API Key + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// The to use for logging. If null, no logging will be performed. + public AzureAIInferenceChatCompletionService( + string? modelId = null, + string? apiKey = null, + Uri? endpoint = null, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + { + this._core = new( + modelId, + apiKey, + endpoint, + httpClient, + loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService))); + } + + /// + /// Initializes a new instance of the class. + /// + /// Target Model Id for endpoints supporting more than one model + /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// The to use for logging. If null, no logging will be performed. + public AzureAIInferenceChatCompletionService( + string? modelId, + TokenCredential credential, + Uri? endpoint = null, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + { + this._core = new( + modelId, + credential, + endpoint, + httpClient, + loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService))); + } + + /// + /// Initializes a new instance of the class providing your own ChatCompletionsClient instance. 
+ /// + /// Target Model Id for endpoints supporting more than one model + /// Breaking glass for HTTP requests. + /// The to use for logging. If null, no logging will be performed. + public AzureAIInferenceChatCompletionService( + string? modelId, + ChatCompletionsClient chatClient, + ILoggerFactory? loggerFactory = null) + { + this._core = new( + modelId, + chatClient, + loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService))); + } + + /// + public IReadOnlyDictionary Attributes => this._core.Attributes; + + /// + public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + => this._core.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); + + /// + public IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + => this._core.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); +} diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Settings/AzureAIInferencePromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Settings/AzureAIInferencePromptExecutionSettings.cs new file mode 100644 index 000000000000..db502f3ebf4d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Settings/AzureAIInferencePromptExecutionSettings.cs @@ -0,0 +1,281 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text.Json; +using System.Text.Json.Serialization; +using Azure.AI.Inference; +using Microsoft.SemanticKernel.Text; + +namespace Microsoft.SemanticKernel.Connectors.AzureAIInference; + +/// +/// Chat completion prompt execution settings. +/// +[JsonNumberHandling(JsonNumberHandling.AllowReadingFromString)] +public sealed class AzureAIInferencePromptExecutionSettings : PromptExecutionSettings +{ + /// + /// Initializes a new instance of the class. + /// + public AzureAIInferencePromptExecutionSettings() + { + this.ExtensionData = new Dictionary(); + } + + /// + /// Allowed values: "error" | "drop" | "pass-through" + /// + [JsonPropertyName("extra_parameters")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ExtraParameters + { + get => this._extraParameters; + set + { + this.ThrowIfFrozen(); + this._extraParameters = value; + } + } + + /// + /// A value that influences the probability of generated tokens appearing based on their cumulative + /// frequency in generated text. + /// Positive values will make tokens less likely to appear as their frequency increases and + /// decrease the likelihood of the model repeating the same statements verbatim. + /// Supported range is [-2, 2]. + /// + [JsonPropertyName("frequency_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? FrequencyPenalty + { + get => this._frequencyPenalty; + set + { + this.ThrowIfFrozen(); + this._frequencyPenalty = value; + } + } + + /// + /// A value that influences the probability of generated tokens appearing based on their existing + /// presence in generated text. + /// Positive values will make tokens less likely to appear when they already exist and increase the + /// model's likelihood to output new topics. + /// Supported range is [-2, 2]. 
+ /// + [JsonPropertyName("presence_penalty")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? PresencePenalty + { + get => this._presencePenalty; + set + { + this.ThrowIfFrozen(); + this._presencePenalty = value; + } + } + + /// + /// The sampling temperature to use that controls the apparent creativity of generated completions. + /// Higher values will make output more random while lower values will make results more focused + /// and deterministic. + /// It is not recommended to modify temperature and top_p for the same completions request as the + /// interaction of these two settings is difficult to predict. + /// Supported range is [0, 1]. + /// + [JsonPropertyName("temperature")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? Temperature + { + get => this._temperature; + set + { + this.ThrowIfFrozen(); + this._temperature = value; + } + } + + /// + /// An alternative to sampling with temperature called nucleus sampling. This value causes the + /// model to consider the results of tokens with the provided probability mass. As an example, a + /// value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + /// considered. + /// It is not recommended to modify temperature and top_p for the same completions request as the + /// interaction of these two settings is difficult to predict. + /// Supported range is [0, 1]. + /// + [JsonPropertyName("top_p")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? NucleusSamplingFactor + { + get => this._nucleusSamplingFactor; + set + { + this.ThrowIfFrozen(); + this._nucleusSamplingFactor = value; + } + } + + /// The maximum number of tokens to generate. + [JsonPropertyName("max_tokens")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? MaxTokens + { + get => this._maxTokens; + set + { + this.ThrowIfFrozen(); + this._maxTokens = value; + } + } + + /// + /// The format that the model must output. Use this to enable JSON mode instead of the default text mode. + /// Note that to enable JSON mode, some AI models may also require you to instruct the model to produce JSON + /// via a system or user message. + /// Please note is the base class. According to the scenario, a derived class of the base class might need to be assigned here, or this property needs to be casted to one of the possible derived classes. + /// The available derived classes include and . + /// + [JsonPropertyName("response_format")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public object? ResponseFormat + { + get => this._responseFormat; + set + { + this.ThrowIfFrozen(); + this._responseFormat = value; + } + } + + /// A collection of textual sequences that will end completions generation. + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public IList StopSequences + { + get => this._stopSequences; + set + { + this.ThrowIfFrozen(); + this._stopSequences = value; + } + } + + /// + /// The available tool definitions that the chat completions request can use, including caller-defined functions. + /// Please note is the base class. According to the scenario, a derived class of the base class might need to be assigned here, or this property needs to be casted to one of the possible derived classes. + /// The available derived classes include . 
+ /// + [JsonPropertyName("tools")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public IList Tools + { + get => this._tools; + set + { + this.ThrowIfFrozen(); + this._tools = value; + } + } + + /// + /// If specified, the system will make a best effort to sample deterministically such that repeated requests with the + /// same seed and parameters should return the same result. Determinism is not guaranteed. + /// + [JsonPropertyName("seed")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public long? Seed + { + get => this._seed; + set + { + this.ThrowIfFrozen(); + this._seed = value; + } + } + + /// + public override void Freeze() + { + if (this.IsFrozen) + { + return; + } + + base.Freeze(); + + if (this._stopSequences is not null) + { + this._stopSequences = new ReadOnlyCollection(this._stopSequences); + } + + if (this._tools is not null) + { + this._tools = new ReadOnlyCollection(this._tools); + } + } + + /// + public override PromptExecutionSettings Clone() + { + return new AzureAIInferencePromptExecutionSettings() + { + ExtraParameters = this.ExtraParameters, + FrequencyPenalty = this.FrequencyPenalty, + PresencePenalty = this.PresencePenalty, + Temperature = this.Temperature, + NucleusSamplingFactor = this.NucleusSamplingFactor, + MaxTokens = this.MaxTokens, + ResponseFormat = this.ResponseFormat, + StopSequences = new List(this.StopSequences), + Tools = new List(this.Tools), + Seed = this.Seed, + ExtensionData = this.ExtensionData is not null ? new Dictionary(this.ExtensionData) : null, + }; + } + + /// + /// Create a new settings object with the values from another settings object. + /// + /// Template configuration + /// An instance of + public static AzureAIInferencePromptExecutionSettings FromExecutionSettings(PromptExecutionSettings? executionSettings) + { + if (executionSettings is null) + { + return new AzureAIInferencePromptExecutionSettings(); + } + + if (executionSettings is AzureAIInferencePromptExecutionSettings settings) + { + return settings; + } + + var json = JsonSerializer.Serialize(executionSettings); + + var aiInferenceSettings = JsonSerializer.Deserialize(json, JsonOptionsCache.ReadPermissive); + if (aiInferenceSettings is not null) + { + return aiInferenceSettings; + } + + throw new ArgumentException($"Invalid execution settings, cannot convert to {nameof(AzureAIInferencePromptExecutionSettings)}", nameof(executionSettings)); + } + + #region private ================================================================================ + + private string? _extraParameters; + private float? _frequencyPenalty; + private float? _presencePenalty; + private float? _temperature; + private float? _nucleusSamplingFactor; + private int? _maxTokens; + private object? _responseFormat; + private IList _stopSequences = []; + private IList _tools = []; + private long? 
_seed; + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs index f41b204058ed..017732f6d19c 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/ClientCoreTests.cs @@ -135,7 +135,7 @@ public async Task ItAddSemanticKernelHeadersOnEachRequestAsync() } [Fact] - public async Task ItDoNotAddSemanticKernelHeadersWhenOpenAIClientIsProvidedAsync() + public async Task ItDoesNotAddSemanticKernelHeadersWhenOpenAIClientIsProvidedAsync() { using HttpMessageHandlerStub handler = new(); using HttpClient client = new(handler); @@ -169,7 +169,7 @@ public async Task ItDoNotAddSemanticKernelHeadersWhenOpenAIClientIsProvidedAsync [InlineData(null)] [InlineData("")] [InlineData("value")] - public void ItAddAttributesButDoesNothingIfNullOrEmpty(string? value) + public void ItAddsAttributesButDoesNothingIfNullOrEmpty(string? value) { // Arrange var clientCore = new ClientCore("model", "apikey"); @@ -190,7 +190,7 @@ public void ItAddAttributesButDoesNothingIfNullOrEmpty(string? value) } [Fact] - public void ItAddModelIdAttributeAsExpected() + public void ItAddsModelIdAttributeAsExpected() { // Arrange var expectedModelId = "modelId"; diff --git a/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs new file mode 100644 index 000000000000..140e16fc97cc --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatCompletionServiceTests.cs @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; +using Azure.AI.Inference; +using Azure.Identity; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Http.Resilience; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureAIInference; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. 
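// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): a minimal example of calling
// the new connector directly, mirroring what the integration tests below
// exercise. The model id, endpoint and API key are placeholders, and the
// connector is experimental, so consumers must opt in to SKEXP0070.
// ---------------------------------------------------------------------------
file static class AzureAIInferenceUsageSketch
{
    public static async Task RunAsync()
    {
        var service = new AzureAIInferenceChatCompletionService(
            modelId: "phi3-medium-4k",             // hypothetical; optional for single-model endpoints
            apiKey: "<api-key>",                   // placeholder, never hard-code real keys
            endpoint: new Uri("https://contoso.models.ai.azure.com/"));

        var chatHistory = new ChatHistory();
        chatHistory.AddUserMessage("Where is the most famous fish market in Seattle, Washington, USA?");

        var settings = new AzureAIInferencePromptExecutionSettings
        {
            Temperature = 0.5f,   // lower temperature for a more focused answer
            MaxTokens = 256,
        };

        // Non-streaming call; GetStreamingChatMessageContentsAsync can be
        // consumed with `await foreach` for incremental updates instead.
        var reply = await service.GetChatMessageContentAsync(chatHistory, settings);
        Console.WriteLine(reply.Content);
    }
}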
+ +public sealed class AzureAIInferenceChatCompletionServiceTests(ITestOutputHelper output) : BaseIntegrationTest, IDisposable +{ + private const string InputParameterName = "input"; + private readonly XunitLogger _loggerFactory = new(output); + private readonly RedirectOutput _testOutputHelper = new(output); + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task InvokeGetChatMessageContentsAsync(string prompt, string expectedAnswerContains) + { + // Arrange + var config = this._configuration.GetSection("AzureAIInference").Get(); + Assert.NotNull(config); + + var sut = (config.ApiKey is not null) + ? new AzureAIInferenceChatCompletionService( + endpoint: config.Endpoint, + apiKey: config.ApiKey, + loggerFactory: this._loggerFactory) + : new AzureAIInferenceChatCompletionService( + modelId: null, + endpoint: config.Endpoint, + credential: new AzureCliCredential(), + loggerFactory: this._loggerFactory); + + ChatHistory chatHistory = [ + new ChatMessageContent(AuthorRole.User, prompt) + ]; + + // Act + var result = await sut.GetChatMessageContentsAsync(chatHistory); + + // Assert + Assert.Single(result); + Assert.Contains(expectedAnswerContains, result[0].Content, StringComparison.OrdinalIgnoreCase); + } + + [Theory] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task InvokeGetStreamingChatMessageContentsAsync(string prompt, string expectedAnswerContains) + { + // Arrange + var config = this._configuration.GetSection("AzureAIInference").Get(); + Assert.NotNull(config); + + var sut = (config.ApiKey is not null) + ? 
new AzureAIInferenceChatCompletionService( + endpoint: config.Endpoint, + apiKey: config.ApiKey, + loggerFactory: this._loggerFactory) + : new AzureAIInferenceChatCompletionService( + modelId: null, + endpoint: config.Endpoint, + credential: new AzureCliCredential(), + loggerFactory: this._loggerFactory); + + ChatHistory chatHistory = [ + new ChatMessageContent(AuthorRole.User, prompt) + ]; + + StringBuilder fullContent = new(); + + // Act + await foreach (var update in sut.GetStreamingChatMessageContentsAsync(chatHistory)) + { + fullContent.Append(update.Content); + } + + // Assert + Assert.Contains(expectedAnswerContains, fullContent.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ItCanUseChatForTextGenerationAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var func = kernel.CreateFunctionFromPrompt( + "List the two planets after '{{$input}}', excluding moons, using bullet points.", + new AzureAIInferencePromptExecutionSettings()); + + // Act + var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task ItStreamingFromKernelTestAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResult = new(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResult.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ItHttpRetryPolicyTestAsync() + { + // Arrange + List statusCodes = []; + + var config = this._configuration.GetSection("AzureAIInference").Get(); + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + + var kernelBuilder = Kernel.CreateBuilder(); + + kernelBuilder.AddAzureAIInferenceChatCompletion(endpoint: config.Endpoint, apiKey: null); + + kernelBuilder.Services.ConfigureHttpClientDefaults(c => + { + // Use a standard resiliency policy, augmented to retry on 401 Unauthorized for this example + c.AddStandardResilienceHandler().Configure(o => + { + o.Retry.ShouldHandle = args => ValueTask.FromResult(args.Outcome.Result?.StatusCode is HttpStatusCode.Unauthorized); + o.Retry.OnRetry = args => + { + statusCodes.Add(args.Outcome.Result?.StatusCode); + return ValueTask.CompletedTask; + }; + }); + }); + + var target = kernelBuilder.Build(); + + var plugins = TestHelpers.ImportSamplePlugins(target, "SummarizePlugin"); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + var exception = await Assert.ThrowsAsync(() => target.InvokeAsync(plugins["SummarizePlugin"]["Summarize"], new() { [InputParameterName] = prompt })); + + // Assert + Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); + Assert.Equal(HttpStatusCode.Unauthorized, ((HttpOperationException)exception).StatusCode); + } + + [Fact] + public async Task ItShouldReturnInnerContentAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = 
+
+        // Act
+        var result = await kernel.InvokeAsync(plugins["FunPlugin"]["Limerick"]);
+        var content = result.GetValue<ChatMessageContent>();
+
+        // Assert
+        Assert.NotNull(content);
+        Assert.NotNull(content.InnerContent);
+
+        Assert.IsType<ChatCompletions>(content.InnerContent);
+        var completions = (ChatCompletions)content.InnerContent;
+        var usage = completions.Usage;
+
+        // Usage
+        Assert.NotEqual(0, usage.PromptTokens);
+        Assert.NotEqual(0, usage.CompletionTokens);
+    }
+
+    [Theory(Skip = "This test is for manual verification.")]
+    [InlineData("\n")]
+    [InlineData("\r\n")]
+    public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding)
+    {
+        // Arrange
+        var prompt =
+            "Given a json input and a request. Apply the request on the json input and return the result. " +
+            $"Put the result in between <result></result> tags{lineEnding}" +
+            $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name""";
+
+        var kernel = this.CreateAndInitializeKernel();
+
+        var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin");
+
+        // Act
+        FunctionResult actual = await kernel.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt });
+
+        // Assert
+        Assert.Contains("John", actual.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
+    }
+
+    private Kernel CreateAndInitializeKernel(HttpClient? httpClient = null)
+    {
+        var config = this._configuration.GetSection("AzureAIInference").Get<AzureAIInferenceConfiguration>();
+        Assert.NotNull(config);
+        Assert.NotNull(config.ApiKey);
+        Assert.NotNull(config.Endpoint);
+
+        var kernelBuilder = base.CreateKernelBuilder();
+
+        kernelBuilder.AddAzureAIInferenceChatCompletion(
+            endpoint: config.Endpoint,
+            apiKey: config.ApiKey,
+            serviceId: config.ServiceId,
+            httpClient: httpClient);
+
+        return kernelBuilder.Build();
+    }
+
+    public void Dispose()
+    {
+        this._loggerFactory.Dispose();
+        this._testOutputHelper.Dispose();
+    }
+}
diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj
index 0ab7bcc04b90..cc3e121a8125 100644
--- a/dotnet/src/IntegrationTests/IntegrationTests.csproj
+++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj
@@ -66,6 +66,7 @@
+    <ProjectReference Include="..\Connectors\Connectors.AzureAIInference\Connectors.AzureAIInference.csproj" />
diff --git a/dotnet/src/IntegrationTests/README.md b/dotnet/src/IntegrationTests/README.md
index 1c646a824251..85a997bbea9a 100644
--- a/dotnet/src/IntegrationTests/README.md
+++ b/dotnet/src/IntegrationTests/README.md
@@ -38,6 +38,10 @@ dotnet user-secrets set "OpenAITextToImage:ServiceId" "dall-e-3"
 dotnet user-secrets set "OpenAITextToImage:ModelId" "dall-e-3"
 dotnet user-secrets set "OpenAITextToImage:ApiKey" "..."
 
+dotnet user-secrets set "AzureAIInference:ServiceId" "azure-ai-inference"
+dotnet user-secrets set "AzureAIInference:ApiKey" "..."
+dotnet user-secrets set "AzureAIInference:Endpoint" "https://contoso.models.ai.azure.com/"
+
 dotnet user-secrets set "AzureOpenAI:ServiceId" "azure-gpt-35-turbo-instruct"
 dotnet user-secrets set "AzureOpenAI:DeploymentName" "gpt-35-turbo-instruct"
 dotnet user-secrets set "AzureOpenAI:ChatDeploymentName" "gpt-4"
diff --git a/dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs b/dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs
new file mode 100644
index 000000000000..664effc9e3a5
--- /dev/null
+++ b/dotnet/src/IntegrationTests/TestSettings/AzureAIInferenceConfiguration.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Diagnostics.CodeAnalysis;
+
+namespace SemanticKernel.IntegrationTests.TestSettings;
+
+[SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated",
+    Justification = "Configuration classes are instantiated through IConfiguration.")]
+internal sealed class AzureAIInferenceConfiguration(Uri endpoint, string apiKey, string? serviceId = null)
+{
+    public Uri Endpoint { get; set; } = endpoint;
+    public string? ApiKey { get; set; } = apiKey;
+    public string? ServiceId { get; set; } = serviceId;
+}
diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json
index 40c064f078c5..95c4fd2d7f3e 100644
--- a/dotnet/src/IntegrationTests/testsettings.json
+++ b/dotnet/src/IntegrationTests/testsettings.json
@@ -5,6 +5,11 @@
     "ChatModelId": "gpt-4o",
     "ApiKey": ""
   },
+  "AzureAIInference": {
+    "ServiceId": "azure-ai-inference",
+    "Endpoint": "",
+    "ApiKey": ""
+  },
   "AzureOpenAI": {
     "ServiceId": "azure-gpt-35-turbo-instruct",
     "DeploymentName": "gpt-35-turbo-instruct",
diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
index 821670e46dbc..01b60b08c9cb 100644
--- a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
+++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs
@@ -23,6 +23,7 @@ public static void Initialize(IConfigurationRoot configRoot)
     public static OpenAIConfig OpenAI => LoadSection<OpenAIConfig>();
     public static OnnxConfig Onnx => LoadSection<OnnxConfig>();
     public static AzureOpenAIConfig AzureOpenAI => LoadSection<AzureOpenAIConfig>();
+    public static AzureAIInferenceConfig AzureAIInference => LoadSection<AzureAIInferenceConfig>();
     public static AzureOpenAIConfig AzureOpenAIImages => LoadSection<AzureOpenAIConfig>();
     public static AzureOpenAIEmbeddingsConfig AzureOpenAIEmbeddings => LoadSection<AzureOpenAIEmbeddingsConfig>();
     public static AzureAISearchConfig AzureAISearch => LoadSection<AzureAISearchConfig>();
@@ -73,6 +74,13 @@ public class OpenAIConfig
         public string ApiKey { get; set; }
     }
 
+    public class AzureAIInferenceConfig
+    {
+        public string ServiceId { get; set; }
+        public string Endpoint { get; set; }
+        public string? ApiKey { get; set; }
+    }
+
     public class OnnxConfig
     {
         public string ModelId { get; set; }
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs
index a452d979c4f5..e96f1272b32f 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs
@@ -48,7 +48,7 @@ public static Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(
     /// <summary>
     /// Get a single chat message content for the prompt and settings.
     /// </summary>
-    /// <param name="chatCompletionService">The target IChatCompletionSErvice interface to extend.</param>
+    /// <param name="chatCompletionService">The target <see cref="IChatCompletionService"/> interface to extend.</param>
     /// <param name="prompt">The standardized prompt input.</param>
     /// <param name="executionSettings">The AI execution settings (optional).</param>
     /// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
@@ -66,7 +66,7 @@ public static async Task<ChatMessageContent> GetChatMessageContentAsync(
     /// <summary>
     /// Get a single chat message content for the chat history and settings provided.
     /// </summary>
-    /// <param name="chatCompletionService">The target IChatCompletionService interface to extend.</param>
+    /// <param name="chatCompletionService">The target <see cref="IChatCompletionService"/> interface to extend.</param>
     /// <param name="chatHistory">The chat history to complete.</param>
     /// <param name="executionSettings">The AI execution settings (optional).</param>
     /// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
@@ -85,7 +85,7 @@ public static async Task<ChatMessageContent> GetChatMessageContentAsync(
     /// Get streaming chat message contents for the chat history provided using the specified settings.
     /// </summary>
     /// <remarks>Throws if the specified type is not the same or fail to cast</remarks>
-    /// <param name="chatCompletionService">The target IChatCompletionService interface to extend.</param>
+    /// <param name="chatCompletionService">The target <see cref="IChatCompletionService"/> interface to extend.</param>
     /// <param name="prompt">The standardized prompt input.</param>
     /// <param name="executionSettings">The AI execution settings (optional).</param>
     /// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
diff --git a/prompt_template_samples/ChatPlugin/Chat/config.json b/prompt_template_samples/ChatPlugin/Chat/config.json
index fa98c67602e8..ae1dba827434 100644
--- a/prompt_template_samples/ChatPlugin/Chat/config.json
+++ b/prompt_template_samples/ChatPlugin/Chat/config.json
@@ -5,9 +5,9 @@
     "default": {
       "max_tokens": 150,
       "temperature": 0.9,
-      "top_p": 0.0,
+      "top_p": 0.1,
       "presence_penalty": 0.6,
-      "frequency_penalty": 0.0,
+      "frequency_penalty": 0.1,
       "stop_sequences": [
         "Human:",
         "AI:"
diff --git a/prompt_template_samples/FunPlugin/Limerick/config.json b/prompt_template_samples/FunPlugin/Limerick/config.json
index f929ede1e31a..659b1f4f897f 100644
--- a/prompt_template_samples/FunPlugin/Limerick/config.json
+++ b/prompt_template_samples/FunPlugin/Limerick/config.json
@@ -5,9 +5,9 @@
     "default": {
       "max_tokens": 100,
       "temperature": 0.7,
-      "top_p": 0,
-      "presence_penalty": 0,
-      "frequency_penalty": 0
+      "top_p": 0.1,
+      "presence_penalty": 0.1,
+      "frequency_penalty": 0.1
     }
   },
   "input_variables": [
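
---

Reviewer note: for anyone who wants to try the connector without running the integration suite, here is a minimal consumption sketch. It is not part of the patch: the endpoint URL and API key are placeholders, and it assumes only the `AddAzureAIInferenceChatCompletion(endpoint:, apiKey:)` overload that the integration tests above exercise.

```csharp
// Minimal sketch, not part of this patch: register the new Azure AI Inference
// connector and run a single prompt. Endpoint/key values are placeholders.
using System;
using Microsoft.SemanticKernel;

var builder = Kernel.CreateBuilder();

// Extension method introduced by this PR (Connectors.AzureAIInference);
// serverless endpoints from Azure AI Studio look like the placeholder below.
builder.AddAzureAIInferenceChatCompletion(
    endpoint: new Uri("https://contoso.models.ai.azure.com/"), // placeholder endpoint
    apiKey: "...");                                            // placeholder key

var kernel = builder.Build();

// Route a simple prompt through the registered chat completion service.
var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle?");
Console.WriteLine(result);
```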