diff --git a/dotnet/samples/Concepts/ChatCompletion/Onnx_ChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/Onnx_ChatCompletionStreaming.cs index d6ad1f05e7f2..d07c6e3240d1 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Onnx_ChatCompletionStreaming.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Onnx_ChatCompletionStreaming.cs @@ -135,7 +135,7 @@ public async Task StreamTextFromChatAsync() } } - private async Task StartStreamingChatAsync(IChatCompletionService chatCompletionService) + private async Task StartStreamingChatAsync(OnnxRuntimeGenAIChatCompletionService chatCompletionService) { Console.WriteLine("Chat content:"); Console.WriteLine("------------------------"); @@ -158,7 +158,7 @@ private async Task StartStreamingChatAsync(IChatCompletionService chatCompletion await StreamMessageOutputAsync(chatCompletionService, chatHistory, AuthorRole.Assistant); } - private async Task StreamMessageOutputAsync(IChatCompletionService chatCompletionService, ChatHistory chatHistory, AuthorRole authorRole) + private async Task StreamMessageOutputAsync(OnnxRuntimeGenAIChatCompletionService chatCompletionService, ChatHistory chatHistory, AuthorRole authorRole) { bool roleWritten = false; string fullMessage = string.Empty; diff --git a/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs index 8a6210253729..2e9e16212a0b 100644 --- a/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; @@ -18,13 +17,11 @@ namespace Microsoft.SemanticKernel.Connectors.Onnx; /// /// Represents a chat completion service using OnnxRuntimeGenAI. /// -public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionService, IDisposable +public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionService { private readonly string _modelId; private readonly string _modelPath; private readonly JsonSerializerOptions? _jsonSerializerOptions; - private Model? _model; - private Tokenizer? _tokenizer; private Dictionary AttributesInternal { get; } = new(); @@ -90,13 +87,17 @@ private async IAsyncEnumerable RunInferenceAsync(ChatHistory chatHistory OnnxRuntimeGenAIPromptExecutionSettings onnxPromptExecutionSettings = this.GetOnnxPromptExecutionSettingsSettings(executionSettings); var prompt = this.GetPrompt(chatHistory, onnxPromptExecutionSettings); - var tokens = this.GetTokenizer().Encode(prompt); - using var generatorParams = new GeneratorParams(this.GetModel()); + using var ogaHandle = new OgaHandle(); + using var model = new Model(this._modelPath); + using var tokenizer = new Tokenizer(model); + + var tokens = tokenizer.Encode(prompt); + + using var generatorParams = new GeneratorParams(model); this.UpdateGeneratorParamsFromPromptExecutionSettings(generatorParams, onnxPromptExecutionSettings); generatorParams.SetInputSequences(tokens); - - using var generator = new Generator(this.GetModel(), generatorParams); + using var generator = new Generator(model, generatorParams); bool removeNextTokenStartingWithSpace = true; while (!generator.IsDone()) @@ -110,7 +111,7 @@ private async IAsyncEnumerable RunInferenceAsync(ChatHistory chatHistory var outputTokens = generator.GetSequence(0); var newToken = outputTokens.Slice(outputTokens.Length - 1, 1); - string output = this.GetTokenizer().Decode(newToken); + string output = tokenizer.Decode(newToken); if (removeNextTokenStartingWithSpace && output[0] == ' ') { @@ -123,10 +124,6 @@ private async IAsyncEnumerable RunInferenceAsync(ChatHistory chatHistory } } - private Model GetModel() => this._model ??= new Model(this._modelPath); - - private Tokenizer GetTokenizer() => this._tokenizer ??= new Tokenizer(this.GetModel()); - private string GetPrompt(ChatHistory chatHistory, OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings) { var promptBuilder = new StringBuilder(); @@ -206,11 +203,4 @@ private OnnxRuntimeGenAIPromptExecutionSettings GetOnnxPromptExecutionSettingsSe return OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings); } - - /// - public void Dispose() - { - this._tokenizer?.Dispose(); - this._model?.Dispose(); - } } diff --git a/dotnet/src/IntegrationTests/Connectors/Onnx/OnnxRuntimeGenAIChatCompletionServiceTests.cs b/dotnet/src/IntegrationTests/Connectors/Onnx/OnnxRuntimeGenAIChatCompletionServiceTests.cs index c6359e3b17a5..c042f633d495 100644 --- a/dotnet/src/IntegrationTests/Connectors/Onnx/OnnxRuntimeGenAIChatCompletionServiceTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Onnx/OnnxRuntimeGenAIChatCompletionServiceTests.cs @@ -57,7 +57,7 @@ public async Task ItCanUseKernelInvokeStreamingAsyncAsync() [Fact(Skip = "For manual verification only")] public async Task ItCanUseServiceGetStreamingChatMessageContentsAsync() { - using var chat = CreateService(); + var chat = CreateService(); ChatHistory history = []; history.AddUserMessage("Where is the most famous fish market in Seattle, Washington, USA?"); @@ -76,7 +76,7 @@ public async Task ItCanUseServiceGetStreamingChatMessageContentsAsync() [Fact(Skip = "For manual verification only")] public async Task ItCanUseServiceGetChatMessageContentsAsync() { - using var chat = CreateService(); + var chat = CreateService(); ChatHistory history = []; history.AddUserMessage("Where is the most famous fish market in Seattle, Washington, USA?");