From e49b7e2b42fc462304f43e02b5b4882189174f47 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 22 Nov 2024 10:23:08 -0500 Subject: [PATCH 1/5] Update to latest M.E.AI --- dotnet/Directory.Packages.props | 8 +- ...eAIInferenceServiceCollectionExtensions.cs | 121 +++++++----------- .../AzureAIInferenceChatCompletionService.cs | 88 +++++-------- .../OllamaServiceCollectionExtensions.cs | 120 ++++++++--------- .../ChatCompletionServiceChatClient.cs | 13 +- .../EmbeddingGenerationServiceExtensions.cs | 13 +- .../AI/ServiceConversionExtensionsTests.cs | 4 +- 7 files changed, 154 insertions(+), 213 deletions(-) diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 8ea8825027bb..7770e35f415f 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -56,15 +56,15 @@ - + - - - + + + diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs index 387d9b89a62a..4b2d3eaa1040 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs @@ -8,8 +8,8 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; -using Microsoft.SemanticKernel.Http; namespace Microsoft.SemanticKernel; @@ -38,34 +38,26 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var chatClientBuilder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - var logger = serviceProvider.GetService()?.CreateLogger(); - if (logger is not null) - { - chatClientBuilder.UseLogging(logger); - } - var options = new AzureAIInferenceClientOptions(); + + httpClient ??= serviceProvider.GetService(); if (httpClient is not null) { - options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider)); + options.Transport = new HttpClientTransport(httpClient); } - return - chatClientBuilder.Use( - new Microsoft.Extensions.AI.AzureAIInferenceChatClient( - modelId: modelId, - chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options) - ) - ).AsChatCompletionService(); - }); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return services; + return new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options) + .AsChatClient(modelId) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(serviceProvider); + }); } /// @@ -88,34 +80,26 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var chatClientBuilder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - var logger = serviceProvider.GetService()?.CreateLogger(); - if (logger is not null) - { - chatClientBuilder.UseLogging(logger); - } - var options = new AzureAIInferenceClientOptions(); + + httpClient ??= serviceProvider.GetService(); if (httpClient is not null) { - options.Transport = new HttpClientTransport(HttpClientProvider.GetHttpClient(httpClient, serviceProvider)); + options.Transport = new HttpClientTransport(httpClient); } - return - chatClientBuilder.Use( - new Microsoft.Extensions.AI.AzureAIInferenceChatClient( - modelId: modelId, - chatCompletionsClient: new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options) - ) - ).AsChatCompletionService(); - }); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return services; + return new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options) + .AsChatClient(modelId) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(serviceProvider); + }); } /// @@ -133,26 +117,18 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { chatClient ??= serviceProvider.GetRequiredService(); - - var chatClientBuilder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - var logger = serviceProvider.GetService()?.CreateLogger(); - if (logger is not null) - { - chatClientBuilder.UseLogging(logger); - } - - return chatClientBuilder - .Use(new Microsoft.Extensions.AI.AzureAIInferenceChatClient(chatClient, modelId)) - .AsChatCompletionService(); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return chatClient + .AsChatClient(modelId) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(serviceProvider); }); - - return services; } /// @@ -168,26 +144,17 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { chatClient ??= serviceProvider.GetRequiredService(); - - var chatClientBuilder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - var logger = serviceProvider.GetService()?.CreateLogger(); - if (logger is not null) - { - chatClientBuilder.UseLogging(logger); - } - - return chatClientBuilder - .Use(chatClient) - .AsChatCompletionService(); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return chatClient + .AsBuilder() + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(serviceProvider); }); - - return services; } #region Private diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs index 392f93b47147..868c0c7ff16c 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs @@ -9,6 +9,7 @@ using Azure.Core; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureAIInference.Core; @@ -38,25 +39,16 @@ public AzureAIInferenceChatCompletionService( HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { - var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService)); - this._core = new( - modelId, - apiKey, - endpoint, - httpClient, - logger); - - var builder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - if (logger is not null) - { - builder = builder.UseLogging(logger); - } - - this._chatService = builder - .Use(this._core.Client.AsChatClient(modelId)) + loggerFactory ??= NullLoggerFactory.Instance; + + this._core = new ChatClientCore(modelId, apiKey, endpoint, httpClient); + + this._chatService = this._core.Client + .AsChatClient(modelId) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build() .AsChatCompletionService(); } @@ -75,25 +67,16 @@ public AzureAIInferenceChatCompletionService( HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { - var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService)); - this._core = new( - modelId, - credential, - endpoint, - httpClient, - logger); - - var builder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - if (logger is not null) - { - builder = builder.UseLogging(logger); - } - - this._chatService = builder - .Use(this._core.Client.AsChatClient(modelId)) + loggerFactory ??= NullLoggerFactory.Instance; + + this._core = new ChatClientCore(modelId, credential, endpoint, httpClient); + + this._chatService = this._core.Client + .AsChatClient(modelId) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build() .AsChatCompletionService(); } @@ -108,23 +91,18 @@ public AzureAIInferenceChatCompletionService( ChatCompletionsClient chatClient, ILoggerFactory? loggerFactory = null) { - var logger = loggerFactory?.CreateLogger(typeof(AzureAIInferenceChatCompletionService)); - this._core = new( - modelId, - chatClient, - logger); - - var builder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - if (logger is not null) - { - builder = builder.UseLogging(logger); - } - - this._chatService = builder - .Use(this._core.Client.AsChatClient(modelId)) + Verify.NotNull(chatClient); + + loggerFactory ??= NullLoggerFactory.Instance; + + this._core = new ChatClientCore(modelId, chatClient); + + this._chatService = chatClient + .AsChatClient(modelId) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build() .AsChatCompletionService(); } diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs index d53825079721..aebb9f411f20 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -5,6 +5,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Ollama; using Microsoft.SemanticKernel.Embeddings; @@ -110,24 +111,16 @@ public static IServiceCollection AddOllamaChatCompletion( { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var ollamaClient = new OllamaApiClient(endpoint, modelId); - - var chatClientBuilder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - var logger = serviceProvider.GetService()?.CreateLogger(ollamaClient.GetType()); - if (logger is not null) - { - chatClientBuilder.UseLogging(logger); - } - - return chatClientBuilder.Use(ollamaClient).AsChatCompletionService(serviceProvider); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return ((IChatClient)new OllamaApiClient(endpoint, modelId)) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(serviceProvider); }); - - return services; } /// @@ -146,26 +139,17 @@ public static IServiceCollection AddOllamaChatCompletion( { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var ollamaClient = new OllamaApiClient( - client: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), - modelId); - - var chatClientBuilder = new ChatClientBuilder() - .UseFunctionInvocation(config => - config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - - var logger = serviceProvider.GetService()?.CreateLogger(ollamaClient.GetType()); - if (logger is not null) - { - chatClientBuilder.UseLogging(logger); - } - - return chatClientBuilder.Use(ollamaClient).AsChatCompletionService(serviceProvider); + httpClient ??= HttpClientProvider.GetHttpClient(httpClient, serviceProvider); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return ((IChatClient)new OllamaApiClient(httpClient, modelId)) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(serviceProvider); }); - - return services; } /// @@ -182,10 +166,16 @@ public static IServiceCollection AddOllamaChatCompletion( { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) - => ollamaClient.AsChatCompletionService(serviceProvider)); - - return services; + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return ((IChatClient)ollamaClient) + .AsBuilder() + .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes) + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(serviceProvider); + }); } #endregion @@ -208,22 +198,15 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var ollamaClient = new OllamaApiClient(endpoint, modelId); - - var builder = new EmbeddingGeneratorBuilder>(); - - var logger = serviceProvider.GetService()?.CreateLogger(ollamaClient.GetType()); - if (logger is not null) - { - builder.UseLogging(logger); - } - - return builder.Use(ollamaClient).AsTextEmbeddingGenerationService(serviceProvider); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return ((IEmbeddingGenerator>)new OllamaApiClient(endpoint, modelId)) + .AsBuilder() + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsTextEmbeddingGenerationService(serviceProvider); }); - - return services; } /// @@ -244,19 +227,13 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var ollamaClient = new OllamaApiClient( - client: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), - defaultModel: modelId); - - var builder = new EmbeddingGeneratorBuilder>(); - - var logger = serviceProvider.GetService()?.CreateLogger(ollamaClient.GetType()); - if (logger is not null) - { - builder.UseLogging(logger); - } - - return builder.Use(ollamaClient).AsTextEmbeddingGenerationService(serviceProvider); + httpClient ??= HttpClientProvider.GetHttpClient(httpClient, serviceProvider); + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return ((IEmbeddingGenerator>)new OllamaApiClient(httpClient, modelId)) + .AsBuilder() + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsTextEmbeddingGenerationService(serviceProvider); }); return services; @@ -276,10 +253,15 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( { Verify.NotNull(services); - services.AddKeyedSingleton(serviceId, (serviceProvider, _) - => ollamaClient.AsTextEmbeddingGenerationService(serviceProvider)); - - return services; + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + return ((IEmbeddingGenerator>)ollamaClient) + .AsBuilder() + .UseLogging(loggerFactory) + .Build(serviceProvider) + .AsTextEmbeddingGenerationService(serviceProvider); + }); } #endregion diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs index cab0bce50d26..34750dfce7bd 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs @@ -74,11 +74,18 @@ public void Dispose() } /// - public TService? GetService(object? key = null) where TService : class + public object? GetService(Type serviceType, object? serviceKey = null) { + if (serviceType is null) + { + throw new ArgumentNullException(nameof(serviceType)); + } + return - typeof(TService) == typeof(IChatClient) ? (TService)(object)this : - this._chatCompletionService as TService; + serviceKey is not null ? null : + serviceType.IsInstanceOfType(this) ? this : + serviceType.IsInstanceOfType(this._chatCompletionService) ? this._chatCompletionService : + null; } /// Converts a to a . diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs index 7ae6593f4d2d..4d6ab46f618d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs @@ -124,11 +124,18 @@ public async Task>> GenerateAsync(IEnu } /// - public TService? GetService(object? key = null) where TService : class + public object? GetService(Type serviceType, object? serviceKey = null) { + if (serviceKey is null) + { + throw new ArgumentNullException(nameof(serviceKey)); + } + return - typeof(TService) == typeof(IEmbeddingGenerator>) ? (TService)(object)this : - this._service as TService; + serviceKey is not null ? null : + serviceType.IsInstanceOfType(this) ? this : + serviceType.IsInstanceOfType(this._service) ? this._service : + null; } } diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs index 09f1966e2837..9f8a60b40098 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs @@ -672,7 +672,7 @@ public IAsyncEnumerable CompleteStreamingAsync(IL public void Dispose() { } - public TService? GetService(object? key = null) where TService : class + public object? GetService(Type serviceType, object? serviceKey = null) { return null; } @@ -707,7 +707,7 @@ public Task>> GenerateAsync(IEnumerable(object? key = null) where TService : class + public object? GetService(Type serviceType, object? serviceKey = null) { return null; } From d61b4cf1be739ff65b98989b4c940cddeebcfe65 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 22 Nov 2024 10:46:18 -0500 Subject: [PATCH 2/5] Fix test --- .../Services/AzureAIInferenceChatCompletionServiceTests.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs index 417f32cc545b..a8447d4838a3 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Services/AzureAIInferenceChatCompletionServiceTests.cs @@ -10,6 +10,7 @@ using Azure; using Azure.AI.Inference; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureAIInference; @@ -51,7 +52,6 @@ public void ConstructorsWorksAsExpected() { // Arrange using var httpClient = new HttpClient() { BaseAddress = this._endpoint }; - var loggerFactoryMock = new Mock(); ChatCompletionsClient client = new(this._endpoint, new AzureKeyCredential("api-key")); // Act & Assert @@ -60,12 +60,12 @@ public void ConstructorsWorksAsExpected() new AzureAIInferenceChatCompletionService(modelId: "model", httpClient: httpClient, apiKey: null); // Only the HttpClient with a BaseClass defined new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null); // ModelId and endpoint new AzureAIInferenceChatCompletionService(modelId: "model", apiKey: "api-key", endpoint: this._endpoint); // ModelId, apiKey, and endpoint - new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null, loggerFactory: loggerFactoryMock.Object); // Endpoint and loggerFactory + new AzureAIInferenceChatCompletionService(modelId: "model", endpoint: this._endpoint, apiKey: null, loggerFactory: NullLoggerFactory.Instance); // Endpoint and loggerFactory // Breaking Glass constructor new AzureAIInferenceChatCompletionService(modelId: null, chatClient: client); // Client without model new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client); // Client - new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client, loggerFactory: loggerFactoryMock.Object); // Client + new AzureAIInferenceChatCompletionService(modelId: "model", chatClient: client, loggerFactory: NullLoggerFactory.Instance); // Client } [Theory] From 5686e615479ffd2842e1cb79db599274f9571559 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 22 Nov 2024 14:53:37 -0500 Subject: [PATCH 3/5] Address feedback --- ...eAIInferenceServiceCollectionExtensions.cs | 68 ++++++++---- .../AzureAIInferenceChatCompletionService.cs | 48 ++++---- .../OllamaServiceCollectionExtensions.cs | 103 ++++++++++++------ .../ChatCompletionServiceChatClient.cs | 5 +- .../EmbeddingGenerationServiceExtensions.cs | 5 +- 5 files changed, 141 insertions(+), 88 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs index 4b2d3eaa1040..e77be0cb675f 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs @@ -48,15 +48,19 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( options.Transport = new HttpClientTransport(httpClient); } - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + var loggerFactory = serviceProvider.GetService(); - return new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options) + var builder = new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? SingleSpace), options) .AsChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsChatCompletionService(serviceProvider); + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -90,15 +94,19 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( options.Transport = new HttpClientTransport(httpClient); } - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; + var loggerFactory = serviceProvider.GetService(); - return new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options) + var builder = new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options) .AsChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsChatCompletionService(serviceProvider); + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -120,14 +128,20 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { chatClient ??= serviceProvider.GetRequiredService(); - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return chatClient + + var loggerFactory = serviceProvider.GetService(); + + var builder = chatClient .AsChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsChatCompletionService(serviceProvider); + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -147,13 +161,19 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { chatClient ??= serviceProvider.GetRequiredService(); - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return chatClient + + var loggerFactory = serviceProvider.GetService(); + + var builder = chatClient .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsChatCompletionService(serviceProvider); + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs index 868c0c7ff16c..63e072b1d969 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs @@ -39,17 +39,19 @@ public AzureAIInferenceChatCompletionService( HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { - loggerFactory ??= NullLoggerFactory.Instance; - this._core = new ChatClientCore(modelId, apiKey, endpoint, httpClient); - this._chatService = this._core.Client + var builder = this._core.Client .AsChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build() - .AsChatCompletionService(); + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + this._chatService = builder.Build().AsChatCompletionService(); } /// @@ -67,17 +69,19 @@ public AzureAIInferenceChatCompletionService( HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { - loggerFactory ??= NullLoggerFactory.Instance; - this._core = new ChatClientCore(modelId, credential, endpoint, httpClient); - this._chatService = this._core.Client + var builder = this._core.Client .AsChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build() - .AsChatCompletionService(); + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + this._chatService = builder.Build().AsChatCompletionService(); } /// @@ -93,17 +97,19 @@ public AzureAIInferenceChatCompletionService( { Verify.NotNull(chatClient); - loggerFactory ??= NullLoggerFactory.Instance; - this._core = new ChatClientCore(modelId, chatClient); - this._chatService = chatClient + var builder = chatClient .AsChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build() - .AsChatCompletionService(); + .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + this._chatService = builder.Build().AsChatCompletionService(); } /// diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs index aebb9f411f20..5e04becdc4cd 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -113,13 +113,18 @@ public static IServiceCollection AddOllamaChatCompletion( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return ((IChatClient)new OllamaApiClient(endpoint, modelId)) + var loggerFactory = serviceProvider.GetService(); + + var builder = ((IChatClient)new OllamaApiClient(endpoint, modelId)) .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsChatCompletionService(serviceProvider); + .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -142,13 +147,19 @@ public static IServiceCollection AddOllamaChatCompletion( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { httpClient ??= HttpClientProvider.GetHttpClient(httpClient, serviceProvider); - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return ((IChatClient)new OllamaApiClient(httpClient, modelId)) + + var loggerFactory = serviceProvider.GetService(); + + var builder = ((IChatClient)new OllamaApiClient(httpClient, modelId)) .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsChatCompletionService(serviceProvider); + .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -168,13 +179,18 @@ public static IServiceCollection AddOllamaChatCompletion( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return ((IChatClient)ollamaClient) + var loggerFactory = serviceProvider.GetService(); + + var builder = ((IChatClient)ollamaClient) .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes) - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsChatCompletionService(serviceProvider); + .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -200,12 +216,18 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return ((IEmbeddingGenerator>)new OllamaApiClient(endpoint, modelId)) + var loggerFactory = serviceProvider.GetService(); + + var builder = ((IEmbeddingGenerator>)new OllamaApiClient(endpoint, modelId)) .AsBuilder() - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsTextEmbeddingGenerationService(serviceProvider); + .UseLogging(loggerFactory); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsTextEmbeddingGenerationService(serviceProvider); }); } @@ -228,12 +250,18 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { httpClient ??= HttpClientProvider.GetHttpClient(httpClient, serviceProvider); - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return ((IEmbeddingGenerator>)new OllamaApiClient(httpClient, modelId)) - .AsBuilder() - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsTextEmbeddingGenerationService(serviceProvider); + + var loggerFactory = serviceProvider.GetService(); + + var builder = ((IEmbeddingGenerator>)new OllamaApiClient(httpClient, modelId)) + .AsBuilder(); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsTextEmbeddingGenerationService(serviceProvider); }); return services; @@ -255,12 +283,17 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => { - var loggerFactory = serviceProvider.GetService() ?? NullLoggerFactory.Instance; - return ((IEmbeddingGenerator>)ollamaClient) - .AsBuilder() - .UseLogging(loggerFactory) - .Build(serviceProvider) - .AsTextEmbeddingGenerationService(serviceProvider); + var loggerFactory = serviceProvider.GetService(); + + var builder = ((IEmbeddingGenerator>)ollamaClient) + .AsBuilder(); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(serviceProvider).AsTextEmbeddingGenerationService(serviceProvider); }); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs index 34750dfce7bd..ba9d4e80fc80 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs @@ -76,10 +76,7 @@ public void Dispose() /// public object? GetService(Type serviceType, object? serviceKey = null) { - if (serviceType is null) - { - throw new ArgumentNullException(nameof(serviceType)); - } + Verify.NotNull(serviceType); return serviceKey is not null ? null : diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs index 4d6ab46f618d..c060c3f0d523 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/Embeddings/EmbeddingGenerationServiceExtensions.cs @@ -126,10 +126,7 @@ public async Task>> GenerateAsync(IEnu /// public object? GetService(Type serviceType, object? serviceKey = null) { - if (serviceKey is null) - { - throw new ArgumentNullException(nameof(serviceKey)); - } + Verify.NotNull(serviceType); return serviceKey is not null ? null : From eab4dfaad8394551010a465fcb7690668d0c6557 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 22 Nov 2024 15:03:23 -0500 Subject: [PATCH 4/5] Fix formatting --- .../Extensions/AzureAIInferenceServiceCollectionExtensions.cs | 1 - .../Services/AzureAIInferenceChatCompletionService.cs | 1 - 2 files changed, 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs index e77be0cb675f..c932c27c3831 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs @@ -8,7 +8,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs index 63e072b1d969..a940151e4ec4 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Services/AzureAIInferenceChatCompletionService.cs @@ -9,7 +9,6 @@ using Azure.Core; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureAIInference.Core; From 89881a0ec60325c3794245b6dd44851dd5a2973d Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sat, 23 Nov 2024 07:59:24 -0500 Subject: [PATCH 5/5] Fix formatting --- .../Extensions/OllamaServiceCollectionExtensions.cs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs index 5e04becdc4cd..960466bd9f5d 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -5,7 +5,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Ollama; using Microsoft.SemanticKernel.Embeddings; @@ -219,8 +218,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( var loggerFactory = serviceProvider.GetService(); var builder = ((IEmbeddingGenerator>)new OllamaApiClient(endpoint, modelId)) - .AsBuilder() - .UseLogging(loggerFactory); + .AsBuilder(); if (loggerFactory is not null) {