Skip to content

Commit

Permalink
.Net: MS AI Azure Inference Connector Update (#9640)
Browse files Browse the repository at this point in the history
### Motivation and Context

#### Updates Azure AI Inference Connector to use `Microsoft Extensions
AI`.

- Enables Logging 
- Enables Function Calling
- Adds extra extension for `AzureAIInferenceChatClient`
- Updates Demos/Concepts with the new pattern.

#### ⚠️ Breaking Changes
- The ChatCompletion service is deprecated (marked `[Obsolete]`) in favor of either:
  - `Microsoft.Extensions.AI.AzureAIInferenceChatClient.AsChatCompletionService()`, or
  - `Azure.AI.Inference.ChatCompletionsClient.AsChatClient().AsChatCompletionService()`
    
- `modelId` is required for all ChatCompletion extensions
  • Loading branch information
RogerBarreto authored Nov 13, 2024
1 parent 50bc6f3 commit 4a92f34
Show file tree
Hide file tree
Showing 21 changed files with 654 additions and 541 deletions.
1 change: 1 addition & 0 deletions .github/workflows/dotnet-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ jobs:
# Azure AI Inference Endpoint
AzureAIInference__ApiKey: ${{ secrets.AZUREAIINFERENCE__APIKEY }}
AzureAIInference__Endpoint: ${{ secrets.AZUREAIINFERENCE__ENDPOINT }}
AzureAIInference__ChatModelId: ${{ vars.AZUREAIINFERENCE__CHATMODELID }}

# Generate test reports and check coverage
- name: Generate test reports
Expand Down
1 change: 1 addition & 0 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
<!-- Microsoft.Extensions.* -->
<PackageVersion Include="Microsoft.Extensions.AI" Version="9.0.0-preview.9.24525.1" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
<PackageVersion Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.0.0-preview.9.24525.1" />
<PackageVersion Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Configuration.Binder" Version="8.0.2" />
<PackageVersion Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="8.0.0" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureAIInference;

namespace ChatCompletion;

Expand All @@ -15,9 +16,13 @@ public async Task ServicePromptAsync()
{
Console.WriteLine("======== Azure AI Inference - Chat Completion ========");

var chatService = new AzureAIInferenceChatCompletionService(
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
apiKey: TestConfiguration.AzureAIInference.ApiKey);
Assert.NotNull(TestConfiguration.AzureAIInference.ApiKey);

var chatService = new ChatCompletionsClient(
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
credential: new Azure.AzureKeyCredential(TestConfiguration.AzureAIInference.ApiKey))
.AsChatClient(TestConfiguration.AzureAIInference.ChatModelId)
.AsChatCompletionService();

Console.WriteLine("Chat content:");
Console.WriteLine("------------------------");
Expand Down Expand Up @@ -81,6 +86,7 @@ public async Task ChatPromptAsync()

var kernel = Kernel.CreateBuilder()
.AddAzureAIInferenceChatCompletion(
modelId: TestConfiguration.AzureAIInference.ChatModelId,
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
apiKey: TestConfiguration.AzureAIInference.ApiKey)
.Build();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureAIInference;

namespace ChatCompletion;

Expand All @@ -20,9 +21,11 @@ public Task StreamChatAsync()
{
Console.WriteLine("======== Azure AI Inference - Chat Completion Streaming ========");

var chatService = new AzureAIInferenceChatCompletionService(
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
apiKey: TestConfiguration.AzureAIInference.ApiKey);
var chatService = new ChatCompletionsClient(
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
credential: new Azure.AzureKeyCredential(TestConfiguration.AzureAIInference.ApiKey!))
.AsChatClient(TestConfiguration.AzureAIInference.ChatModelId)
.AsChatCompletionService();

return this.StartStreamingChatAsync(chatService);
}
Expand All @@ -42,6 +45,7 @@ public async Task StreamChatPromptAsync()

var kernel = Kernel.CreateBuilder()
.AddAzureAIInferenceChatCompletion(
modelId: TestConfiguration.AzureAIInference.ChatModelId,
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
apiKey: TestConfiguration.AzureAIInference.ApiKey)
.Build();
Expand All @@ -67,9 +71,11 @@ public async Task StreamTextFromChatAsync()
Console.WriteLine("======== Stream Text from Chat Content ========");

// Create chat completion service
var chatService = new AzureAIInferenceChatCompletionService(
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
apiKey: TestConfiguration.AzureAIInference.ApiKey);
var chatService = new ChatCompletionsClient(
endpoint: new Uri(TestConfiguration.AzureAIInference.Endpoint),
credential: new Azure.AzureKeyCredential(TestConfiguration.AzureAIInference.ApiKey!))
.AsChatClient(TestConfiguration.AzureAIInference.ChatModelId)
.AsChatCompletionService();

// Create chat history with initial system and user messages
ChatHistory chatHistory = new("You are a librarian, an expert on books.");
Expand Down
1 change: 1 addition & 0 deletions dotnet/samples/Concepts/Concepts.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
</PackageReference>
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Microsoft.Extensions.Configuration" />
<PackageReference Include="Microsoft.Extensions.AI.AzureAIInference" />
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" />
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" />
Expand Down
1 change: 1 addition & 0 deletions dotnet/samples/Demos/AIModelRouter/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ private static async Task Main(string[] args)
{
services.AddAzureAIInferenceChatCompletion(
serviceId: "azureai",
modelId: config["AzureAIInference:ChatModelId"]!,
endpoint: new Uri(config["AzureAIInference:Endpoint"]!),
apiKey: config["AzureAIInference:ApiKey"]);

Expand Down
3 changes: 3 additions & 0 deletions dotnet/samples/Demos/AIModelRouter/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ dotnet user-secrets set "OpenAI:ChatModelId" ".. chat completion model .." (defa
dotnet user-secrets set "AzureOpenAI:Endpoint" ".. endpoint .."
dotnet user-secrets set "AzureOpenAI:ChatDeploymentName" ".. chat deployment name .." (default: gpt-4o)
dotnet user-secrets set "AzureOpenAI:ApiKey" ".. api key .." (default: Authenticate with Azure CLI credential)
dotnet user-secrets set "AzureAIInference:ApiKey" ".. api key .."
dotnet user-secrets set "AzureAIInference:Endpoint" ".. endpoint .."
dotnet user-secrets set "AzureAIInference:ChatModelId" ".. chat completion model .."
dotnet user-secrets set "LMStudio:Endpoint" ".. endpoint .." (default: http://localhost:1234)
dotnet user-secrets set "Ollama:ModelId" ".. model id .."
dotnet user-secrets set "Ollama:Endpoint" ".. endpoint .." (default: http://localhost:11434)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="System.Numerics.Tensors" />
<PackageReference Include="System.Text.Json" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureAIInference;
using Xunit;

namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Extensions;
Expand Down Expand Up @@ -37,7 +36,7 @@ public void KernelBuilderAddAzureAIInferenceChatCompletionAddsValidService(Initi

// Assert
var chatCompletionService = builder.Build().GetRequiredService<IChatCompletionService>();
Assert.True(chatCompletionService is AzureAIInferenceChatCompletionService);
Assert.Equal("ChatClientChatCompletionService", chatCompletionService.GetType().Name);
}

public enum InitializationType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
using System;
using Azure;
using Azure.AI.Inference;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureAIInference;
using Xunit;

namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Extensions;
Expand All @@ -18,11 +18,13 @@ public sealed class AzureAIInferenceServiceCollectionExtensionsTests
[Theory]
[InlineData(InitializationType.ApiKey)]
[InlineData(InitializationType.ClientInline)]
[InlineData(InitializationType.ChatClientInline)]
[InlineData(InitializationType.ClientInServiceProvider)]
public void ItCanAddChatCompletionService(InitializationType type)
{
// Arrange
var client = new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("key"));
using var chatClient = new AzureAIInferenceChatClient(client, "model-id");
var builder = Kernel.CreateBuilder();

builder.Services.AddSingleton(client);
Expand All @@ -32,19 +34,21 @@ public void ItCanAddChatCompletionService(InitializationType type)
{
InitializationType.ApiKey => builder.Services.AddAzureAIInferenceChatCompletion("modelId", "api-key", this._endpoint),
InitializationType.ClientInline => builder.Services.AddAzureAIInferenceChatCompletion("modelId", client),
InitializationType.ChatClientInline => builder.Services.AddAzureAIInferenceChatCompletion(chatClient),
InitializationType.ClientInServiceProvider => builder.Services.AddAzureAIInferenceChatCompletion("modelId", chatClient: null),
_ => builder.Services
};

// Assert
var chatCompletionService = builder.Build().GetRequiredService<IChatCompletionService>();
Assert.True(chatCompletionService is AzureAIInferenceChatCompletionService);
Assert.Equal("ChatClientChatCompletionService", chatCompletionService.GetType().Name);
}

public enum InitializationType
{
ApiKey,
ClientInline,
ChatClientInline,
ClientInServiceProvider,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace SemanticKernel.Connectors.AzureAIInference.UnitTests.Services;
/// <summary>
/// Tests for the <see cref="AzureAIInferenceChatCompletionService"/> class.
/// </summary>
[Obsolete("Keeping this test until the service is removed from code-base")]
public sealed class AzureAIInferenceChatCompletionServiceTests : IDisposable
{
private readonly Uri _endpoint = new("https://localhost:1234");
Expand Down Expand Up @@ -55,11 +56,11 @@ public void ConstructorsWorksAsExpected()

// Act & Assert
// Endpoint constructor
new AzureAIInferenceChatCompletionService(endpoint: this._endpoint, apiKey: null); // Only the endpoint
new AzureAIInferenceChatCompletionService(httpClient: httpClient, apiKey: null); // Only the HttpClient with a BaseClass defined
new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null); // Only the endpoint
new AzureAIInferenceChatCompletionService(modelId: "model", httpClient: httpClient, apiKey: null); // Only the HttpClient with a BaseClass defined
new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null); // ModelId and endpoint
new AzureAIInferenceChatCompletionService(modelId: "model", apiKey: "api-key", endpoint: this._endpoint); // ModelId, apiKey, and endpoint
new AzureAIInferenceChatCompletionService(endpoint: this._endpoint, apiKey: null, loggerFactory: loggerFactoryMock.Object); // Endpoint and loggerFactory
new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null, loggerFactory: loggerFactoryMock.Object); // Endpoint and loggerFactory

// Breaking Glass constructor
new AzureAIInferenceChatCompletionService(modelId: null, chatClient: client); // Client without model
Expand Down Expand Up @@ -132,14 +133,14 @@ public async Task ItUsesHttpClientBaseAddressWhenNoEndpointIsProvidedAsync()
public void ItThrowsIfNoEndpointOrNoHttpClientBaseAddressIsProvided()
{
// Act & Assert
Assert.Throws<ArgumentNullException>(() => new AzureAIInferenceChatCompletionService(endpoint: null, httpClient: this._httpClient));
Assert.Throws<ArgumentNullException>(() => new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: null, httpClient: this._httpClient));
}

[Fact]
public async Task ItGetChatMessageContentsShouldHaveModelIdDefinedAsync()
{
// Arrange
var chatCompletion = new AzureAIInferenceChatCompletionService(apiKey: "NOKEY", httpClient: this._httpClientWithBaseAddress);
var chatCompletion = new AzureAIInferenceChatCompletionService(modelId: "model", apiKey: "NOKEY", httpClient: this._httpClientWithBaseAddress);
this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
{ Content = this.CreateDefaultStringContent() };

Expand All @@ -158,7 +159,7 @@ public async Task ItGetChatMessageContentsShouldHaveModelIdDefinedAsync()
public async Task GetStreamingChatMessageContentsWorksCorrectlyAsync()
{
// Arrange
var service = new AzureAIInferenceChatCompletionService(httpClient: this._httpClientWithBaseAddress);
var service = new AzureAIInferenceChatCompletionService(modelId: "model", httpClient: this._httpClientWithBaseAddress);
await using var stream = File.OpenRead("TestData/chat_completion_streaming_response.txt");

this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK)
Expand All @@ -174,7 +175,9 @@ public async Task GetStreamingChatMessageContentsWorksCorrectlyAsync()

await enumerator.MoveNextAsync();
Assert.Equal("Test content", enumerator.Current.Content);
Assert.Equal("stop", enumerator.Current.Metadata?["FinishReason"]);
Assert.IsType<StreamingChatCompletionsUpdate>(enumerator.Current.InnerContent);
StreamingChatCompletionsUpdate innerContent = (StreamingChatCompletionsUpdate)enumerator.Current.InnerContent;
Assert.Equal("stop", innerContent.FinishReason);
}

[Fact]
Expand Down Expand Up @@ -210,7 +213,7 @@ public async Task GetChatMessageContentsWithChatMessageContentItemCollectionCorr

Assert.Equal(3, messages.GetArrayLength());

Assert.Equal(Prompt, messages[0].GetProperty("content").GetString());
Assert.Contains(Prompt, messages[0].GetProperty("content").GetRawText());
Assert.Equal("user", messages[0].GetProperty("role").GetString());

Assert.Equal(AssistantMessage, messages[1].GetProperty("content").GetString());
Expand Down Expand Up @@ -250,7 +253,7 @@ public async Task GetChatMessageInResponseFormatsAsync(string formatType, string
break;
}

var sut = new AzureAIInferenceChatCompletionService(httpClient: this._httpClientWithBaseAddress);
var sut = new AzureAIInferenceChatCompletionService("any", httpClient: this._httpClientWithBaseAddress);
AzureAIInferencePromptExecutionSettings executionSettings = new() { ResponseFormat = format };

this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.Inference" />
<PackageReference Include="Microsoft.Extensions.AI" />
<PackageReference Include="Microsoft.Extensions.AI.AzureAIInference" />
</ItemGroup>
</Project>
Loading

0 comments on commit 4a92f34

Please sign in to comment.