diff --git a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj
index 4d5a3deb9906..9ce032ddcb68 100644
--- a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj
+++ b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj
@@ -12,6 +12,7 @@
+
diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
index 4b948140348f..bdd12c5123c5 100644
--- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
+++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
@@ -12,6 +12,7 @@ using System.Threading.Tasks;
 using Microsoft.Extensions.Logging;
 using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.FunctionCalling;
 using Microsoft.SemanticKernel.Diagnostics;
 using Microsoft.SemanticKernel.Http;
 using Microsoft.SemanticKernel.Text;
@@ -28,30 +29,10 @@ internal sealed class GeminiChatCompletionClient : ClientBase
     private readonly string _modelId;
     private readonly Uri _chatGenerationEndpoint;
     private readonly Uri _chatStreamingEndpoint;
+    private readonly FunctionCallsProcessor _functionCallsProcessor;
     private static readonly string s_namespace = typeof(GoogleAIGeminiChatCompletionService).Namespace!;

-    /// <summary>
-    /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
-    /// asynchronous chain of execution.
-    /// </summary>
-    /// <remarks>
-    /// This is a fail-safe mechanism. If someone accidentally manages to set up execution settings in such a way that
-    /// auto-invocation is invoked recursively, and in particular where a prompt function is able to auto-invoke itself,
-    /// we could end up in an infinite loop. This const is a backstop against that happening. We should never come close
-    /// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution.
-    /// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in
-    /// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that
-    /// prompt function could advertise itself as a candidate for auto-invocation. We don't want to outright block that,
-    /// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent
-    /// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made
-    /// configurable should need arise.
-    /// </remarks>
-    private const int MaxInflightAutoInvokes = 128;
-
-    /// <summary>Tracking <see cref="AsyncLocal{Int32}"/> for <see cref="MaxInflightAutoInvokes"/>.</summary>
-    private static readonly AsyncLocal<int> s_inflightAutoInvokes = new();
-
     /// <summary>
     /// Instance of <see cref="Meter"/> for metrics.
     /// </summary>
@@ -84,6 +65,12 @@ internal sealed class GeminiChatCompletionClient : ClientBase
         unit: "{token}",
         description: "Number of tokens used");

+    private sealed record ToolCallingConfig(
+        IList? Tools,
+        GeminiFunctionCallingMode? Mode,
+        bool AutoInvoke,
+        FunctionChoiceBehaviorOptions? Options);
+
     /// <summary>
     /// Represents a client for interacting with the chat completion Gemini model via GoogleAI.
     /// </summary>
@@ -108,6 +95,7 @@ public GeminiChatCompletionClient(
         string versionSubLink = GetApiVersionSubLink(apiVersion);

         this._modelId = modelId;
+        this._functionCallsProcessor = new FunctionCallsProcessor(this.Logger);
         this._chatGenerationEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:generateContent?key={apiKey}");
         this._chatStreamingEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:streamGenerateContent?key={apiKey}&alt=sse");
     }
@@ -142,6 +130,7 @@ public GeminiChatCompletionClient(
         string versionSubLink = GetApiVersionSubLink(apiVersion);

         this._modelId = modelId;
+        this._functionCallsProcessor = new FunctionCallsProcessor(this.Logger);
         this._chatGenerationEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:generateContent");
         this._chatStreamingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:streamGenerateContent?alt=sse");
     }
@@ -162,11 +151,16 @@ public async Task<IReadOnlyList<GeminiChatMessageContent>> GenerateChatMessageAsync(
     {
         var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

-        for (state.Iteration = 1; ; state.Iteration++)
+        for (state.RequestIndex = 0;; state.RequestIndex++)
         {
+            // TODO: do something with this variable
+            var functionCallingConfig = this.GetFunctionCallingConfiguration(state);
+
+            // TODO: the request should be created here, not above the loop
+
             List<GeminiChatMessageContent> chatResponses;
             using (var activity = ModelDiagnostics.StartCompletionActivity(
-                this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
+                       this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
             {
                 GeminiResponse geminiResponse;
                 try
@@ -190,22 +184,38 @@ public async Task<IReadOnlyList<GeminiChatMessageContent>> GenerateChatMessageAsync(

             // If we don't want to attempt to invoke any functions, just return the result.
             // Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
-            if (!state.AutoInvoke || chatResponses.Count != 1)
+            if (!state.AutoInvoke || chatResponses.Count == 0)
             {
                 return chatResponses;
             }

             state.LastMessage = chatResponses[0];
+            // TODO: should the ToolCalls property be removed from GeminiChatMessageContent?
             if (state.LastMessage.ToolCalls is null)
             {
                 return chatResponses;
             }

-            // ToolCallBehavior is not null because we are in auto-invoke mode but we check it again to be sure it wasn't changed in the meantime
-            Verify.NotNull(state.ExecutionSettings.ToolCallBehavior);
-
-            state.AddLastMessageToChatHistoryAndRequest();
-            await this.ProcessFunctionsAsync(state, cancellationToken).ConfigureAwait(false);
+            // TODO: remove?
+            // state.AddLastMessageToChatHistoryAndRequest();
+
+            // Process function calls by invoking the functions and adding the results to the chat history.
+            // Each function call will trigger auto-function-invocation filters, which can terminate the process.
+            // In such cases, we'll return the last message in the chat history.
+            var lastMessage = await this._functionCallsProcessor.ProcessFunctionCallsAsync(
+                state.LastMessage,
+                chatHistory,
+                state.RequestIndex,
+                content => IsRequestableTool(state.LastMessage.ToolCalls, content),
+                functionCallingConfig.Options ??
+                    new FunctionChoiceBehaviorOptions(),
+                kernel,
+                isStreaming: false,
+                cancellationToken).ConfigureAwait(false);
+
+            if (lastMessage != null)
+            {
+                return [lastMessage];
+            }
         }
     }
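The comment in the hunk above notes that each function call flows through auto-function-invocation filters, which can terminate the loop early. For context, a minimal sketch of such a filter; the filter name and registration are illustrative, but `IAutoFunctionInvocationFilter` and `context.Terminate` are the standard Semantic Kernel API that `FunctionCallsProcessor` honors:

```csharp
using System;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;

// Sketch: stops the tool-calling loop after the first auto-invoked function.
public sealed class StopAfterFirstCallFilter : IAutoFunctionInvocationFilter
{
    public async Task OnAutoFunctionInvocationAsync(
        AutoFunctionInvocationContext context,
        Func<AutoFunctionInvocationContext, Task> next)
    {
        // Let the function (and any downstream filters) run normally.
        await next(context).ConfigureAwait(false);

        // Request termination; the connector then returns the last message in the
        // chat history instead of sending another request to the model.
        context.Terminate = true;
    }
}

// Registration, e.g. at startup:
// kernel.AutoFunctionInvocationFilters.Add(new StopAfterFirstCallFilter());
```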
@@ -225,10 +235,10 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMessagesAsync(
     {
         var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

-        for (state.Iteration = 1; ; state.Iteration++)
+        for (state.RequestIndex = 1;; state.RequestIndex++)
         {
             using (var activity = ModelDiagnostics.StartCompletionActivity(
-                this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
+                       this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
             {
                 HttpResponseMessage? httpResponseMessage = null;
                 Stream? responseStream = null;
@@ -292,6 +302,79 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMessagesAsync(
         }
     }

+    private ToolCallingConfig GetFunctionCallingConfiguration(ChatCompletionState state)
+    {
+        // If no function choice behavior is specified, return the default configuration: no tools and no mode.
+        if (state.ExecutionSettings.FunctionChoiceBehavior is null)
+        {
+            return new ToolCallingConfig(Tools: null, Mode: null, AutoInvoke: false, Options: null);
+        }
+
+        return this.ConfigureFunctionCalling(state);
+    }
+
+    private ToolCallingConfig ConfigureFunctionCalling(ChatCompletionState state)
+    {
+        var config =
+            this._functionCallsProcessor.GetConfiguration(state.ExecutionSettings.FunctionChoiceBehavior, state.ChatHistory, state.RequestIndex, state.Kernel);
+
+        IList? tools = null;
+        GeminiFunctionCallingMode? toolMode = null;
+        bool autoInvoke = config?.AutoInvoke ?? false;
+
+        if (config?.Functions is { Count: > 0 } functions)
+        {
+            if (config.Choice == FunctionChoice.Auto)
+            {
+                toolMode = GeminiFunctionCallingMode.Default;
+            }
+            else if (config.Choice == FunctionChoice.Required)
+            {
+                toolMode = GeminiFunctionCallingMode.Any;
+            }
+            else if (config.Choice == FunctionChoice.None)
+            {
+                toolMode = GeminiFunctionCallingMode.None;
+            }
+            else
+            {
+                throw new NotSupportedException($"Unsupported function choice '{config.Choice}'.");
+            }
+
+            tools = [];
+
+            foreach (var function in functions)
+            {
+                tools.Add(function.Metadata.ToOpenAIFunction().ToFunctionDefinition());
+            }
+        }
+
+        return new ToolCallingConfig(
+            Tools: tools,
+            Mode: toolMode ?? GeminiFunctionCallingMode.None,
+            AutoInvoke: autoInvoke,
+            Options: config?.Options);
+    }
+
+    /// <summary>Checks if a tool call is for a function that was defined.</summary>
+    private static bool IsRequestableTool(IReadOnlyList<GeminiFunctionToolCall> tools, FunctionCallContent functionCallContent)
+    {
+        foreach (var tool in tools)
+        {
+            if (string.Equals(tool.FunctionName,
+                    FunctionName.ToFullyQualifiedName(
+                        functionCallContent.FunctionName,
+                        functionCallContent.PluginName,
+                        GeminiFunction.NameSeparator),
+                    StringComparison.OrdinalIgnoreCase))
+            {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
     private ChatCompletionState ValidateInputAndCreateChatCompletionState(
         ChatHistory chatHistory,
         Kernel? kernel,
@@ -391,7 +474,7 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, CancellationToken cancellationToken)
         // Clear the tools. If we end up wanting to use tools, we'll reset it to the desired value.
         state.GeminiRequest.Tools = null;

-        if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
+        if (state.RequestIndex >= state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
         {
             // Don't add any tools as we've reached the maximum attempts limit.
             if (this.Logger.IsEnabled(LogLevel.Debug))
@@ -408,7 +491,7 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, CancellationToken cancellationToken)
         }

         // Disable auto invocation if we've exceeded the allowed limit.
-        if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
+        if (state.RequestIndex >= state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
         {
             state.AutoInvoke = false;
             if (this.Logger.IsEnabled(LogLevel.Debug))
@@ -481,11 +564,6 @@ private async Task SendRequestAndReturnValidGeminiResponseAsync(
         return geminiResponse;
     }

-    /// <summary>Checks if a tool call is for a function that was defined.</summary>
-    private static bool IsRequestableTool(IEnumerable<GeminiFunction> functions, GeminiFunctionToolCall ftc)
-        => functions.Any(geminiFunction =>
-            string.Equals(geminiFunction.Name, ftc.FullyQualifiedName, StringComparison.OrdinalIgnoreCase));
-
     private void AddToolResponseMessage(
         ChatHistory chat,
         GeminiRequest request,
@@ -591,8 +669,8 @@ private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
     }

     private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
-        => geminiResponse.Candidates == null ?
-            [new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
+        => geminiResponse.Candidates == null
+            ? [new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
             : geminiResponse.Candidates.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();

     private GeminiChatMessageContent GetChatMessageContentFromCandidate(GeminiResponse geminiResponse, GeminiResponseCandidate candidate)
@@ -663,17 +741,17 @@ private static void ValidateAutoInvoke(bool autoInvoke, int resultsPerPrompt)
     private static GeminiMetadata GetResponseMetadata(
         GeminiResponse geminiResponse,
         GeminiResponseCandidate candidate) => new()
-    {
-        FinishReason = candidate.FinishReason,
-        Index = candidate.Index,
-        PromptTokenCount = geminiResponse.UsageMetadata?.PromptTokenCount ?? 0,
-        CurrentCandidateTokenCount = candidate.TokenCount,
-        CandidatesTokenCount = geminiResponse.UsageMetadata?.CandidatesTokenCount ?? 0,
-        TotalTokenCount = geminiResponse.UsageMetadata?.TotalTokenCount ?? 0,
-        PromptFeedbackBlockReason = geminiResponse.PromptFeedback?.BlockReason,
-        PromptFeedbackSafetyRatings = geminiResponse.PromptFeedback?.SafetyRatings.ToList(),
-        ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
-    };
+        {
+            FinishReason = candidate.FinishReason,
+            Index = candidate.Index,
+            PromptTokenCount = geminiResponse.UsageMetadata?.PromptTokenCount ?? 0,
+            CurrentCandidateTokenCount = candidate.TokenCount,
+            CandidatesTokenCount = geminiResponse.UsageMetadata?.CandidatesTokenCount ?? 0,
+            TotalTokenCount = geminiResponse.UsageMetadata?.TotalTokenCount ?? 0,
+            PromptFeedbackBlockReason = geminiResponse.PromptFeedback?.BlockReason,
+            PromptFeedbackSafetyRatings = geminiResponse.PromptFeedback?.SafetyRatings.ToList(),
+            ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
+        };

     private sealed class ChatCompletionState
     {
@@ -682,7 +760,7 @@ private sealed class ChatCompletionState
         internal Kernel Kernel { get; set; } = null!;
         internal GeminiPromptExecutionSettings ExecutionSettings { get; set; } = null!;
         internal GeminiChatMessageContent? LastMessage { get; set; }
-        internal int Iteration { get; set; }
+        internal int RequestIndex { get; set; }
         internal bool AutoInvoke { get; set; }

         internal void AddLastMessageToChatHistoryAndRequest()
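With the client reworked on top of `FunctionCallsProcessor`, callers opt into tool calling through `FunctionChoiceBehavior` instead of the removed `ToolCallBehavior`. A hedged end-to-end sketch of the intended usage (the model id, API key, and commented-out plugin are placeholders, and details may shift while the TODOs above remain open):

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;

var kernel = Kernel.CreateBuilder()
    .AddGoogleAIGeminiChatCompletion(modelId: "gemini-1.5-pro", apiKey: "<api-key>")
    .Build();
// kernel.ImportPluginFromType<WeatherPlugin>(); // hypothetical plugin

var settings = new GeminiPromptExecutionSettings
{
    // Advertise all kernel functions and auto-invoke them; this is the input that
    // GetFunctionCallingConfiguration maps to GeminiFunctionCallingMode.Default.
    FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(),
};

var chat = kernel.GetRequiredService<IChatCompletionService>();
var history = new ChatHistory();
history.AddUserMessage("What's the weather in Paris?");

var reply = await chat.GetChatMessageContentAsync(history, settings, kernel);
```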
string.Empty; +} + +internal sealed class GeminiFunctionCallingModeConverter : JsonConverter +{ + public override GeminiFunctionCallingMode Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + => new(reader.GetString()!); + + public override void Write(Utf8JsonWriter writer, GeminiFunctionCallingMode value, JsonSerializerOptions options) + => writer.WriteStringValue(value.Label); +} diff --git a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs index 00821e9a2760..37a43ed2a9bb 100644 --- a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs @@ -6,7 +6,6 @@ using System.Linq; using System.Text.Json; using System.Text.Json.Serialization; -using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Text; namespace Microsoft.SemanticKernel.Connectors.Google; @@ -25,7 +24,6 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings private IList? _stopSequences; private bool? _audioTimestamp; private IList? _safetySettings; - private GeminiToolCallBehavior? _toolCallBehavior; /// /// Default max tokens for a text generation. @@ -135,43 +133,6 @@ public IList? SafetySettings } } - /// - /// Gets or sets the behavior for how tool calls are handled. - /// - /// - /// - /// To disable all tool calling, set the property to null (the default). - /// - /// To allow the model to request one of any number of functions, set the property to an - /// instance returned from , called with - /// a list of the functions available. - /// - /// - /// To allow the model to request one of any of the functions in the supplied , - /// set the property to if the client should simply - /// send the information about the functions and not handle the response in any special manner, or - /// if the client should attempt to automatically - /// invoke the function and send the result back to the service. - /// - /// - /// For all options where an instance is provided, auto-invoke behavior may be selected. If the service - /// sends a request for a function call, if auto-invoke has been requested, the client will attempt to - /// resolve that function from the functions available in the , and if found, rather - /// than returning the response back to the caller, it will handle the request automatically, invoking - /// the function, and sending back the result. The intermediate messages will be retained in the - /// if an instance was provided. - /// - public GeminiToolCallBehavior? ToolCallBehavior - { - get => this._toolCallBehavior; - - set - { - this.ThrowIfFrozen(); - this._toolCallBehavior = value; - } - } - /// /// Indicates if the audio response should include timestamps. /// if enabled, audio timestamp will be included in the request to the model. @@ -222,7 +183,6 @@ public override PromptExecutionSettings Clone() CandidateCount = this.CandidateCount, StopSequences = this.StopSequences is not null ? 
diff --git a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs
index 00821e9a2760..37a43ed2a9bb 100644
--- a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs
+++ b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs
@@ -6,7 +6,6 @@ using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
-using Microsoft.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel.Text;

 namespace Microsoft.SemanticKernel.Connectors.Google;
@@ -25,7 +24,6 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings
     private IList<string>? _stopSequences;
     private bool? _audioTimestamp;
     private IList<GeminiSafetySetting>? _safetySettings;
-    private GeminiToolCallBehavior? _toolCallBehavior;

     /// <summary>
     /// Default max tokens for a text generation.
     /// </summary>
@@ -135,43 +133,6 @@ public IList<GeminiSafetySetting>? SafetySettings
         }
     }

-    /// <summary>
-    /// Gets or sets the behavior for how tool calls are handled.
-    /// </summary>
-    /// <remarks>
-    /// <list type="bullet">
-    /// <item>To disable all tool calling, set the property to null (the default).</item>
-    /// <item>
-    /// To allow the model to request one of any number of functions, set the property to an
-    /// instance returned from <see cref="GeminiToolCallBehavior.EnableFunctions(IEnumerable{GeminiFunction}, bool)"/>, called with
-    /// a list of the functions available.
-    /// </item>
-    /// <item>
-    /// To allow the model to request one of any of the functions in the supplied <see cref="Kernel"/>,
-    /// set the property to <see cref="GeminiToolCallBehavior.EnableKernelFunctions"/> if the client should simply
-    /// send the information about the functions and not handle the response in any special manner, or
-    /// <see cref="GeminiToolCallBehavior.AutoInvokeKernelFunctions"/> if the client should attempt to automatically
-    /// invoke the function and send the result back to the service.
-    /// </item>
-    /// </list>
-    /// For all options where an instance is provided, auto-invoke behavior may be selected. If the service
-    /// sends a request for a function call, if auto-invoke has been requested, the client will attempt to
-    /// resolve that function from the functions available in the <see cref="Kernel"/>, and if found, rather
-    /// than returning the response back to the caller, it will handle the request automatically, invoking
-    /// the function, and sending back the result. The intermediate messages will be retained in the
-    /// <see cref="ChatHistory"/> if an instance was provided.
-    /// </remarks>
-    public GeminiToolCallBehavior? ToolCallBehavior
-    {
-        get => this._toolCallBehavior;
-
-        set
-        {
-            this.ThrowIfFrozen();
-            this._toolCallBehavior = value;
-        }
-    }
-
     /// <summary>
     /// Indicates if the audio response should include timestamps.
     /// If enabled, audio timestamp will be included in the request to the model.
     /// </summary>
@@ -222,7 +183,6 @@ public override PromptExecutionSettings Clone()
             CandidateCount = this.CandidateCount,
             StopSequences = this.StopSequences is not null ? new List<string>(this.StopSequences) : null,
             SafetySettings = this.SafetySettings?.Select(setting => new GeminiSafetySetting(setting)).ToList(),
-            ToolCallBehavior = this.ToolCallBehavior?.Clone(),
             AudioTimestamp = this.AudioTimestamp
         };
     }
@@ -250,6 +210,11 @@ public static GeminiPromptExecutionSettings FromExecutionSettings(PromptExecutionSettings? executionSettings)
         }

         var json = JsonSerializer.Serialize(executionSettings);
-        return JsonSerializer.Deserialize<GeminiPromptExecutionSettings>(json, JsonOptionsCache.ReadPermissive)!;
+        var geminiPromptExecutionSettings = JsonSerializer.Deserialize<GeminiPromptExecutionSettings>(json, JsonOptionsCache.ReadPermissive)!;
+
+        // Restore the function choice behavior, which loses its internal state (the list of function instances) during the serialization/deserialization round trip.
+        geminiPromptExecutionSettings.FunctionChoiceBehavior = executionSettings.FunctionChoiceBehavior;
+
+        return geminiPromptExecutionSettings;
     }
 }
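The restore step exists because `FunctionChoiceBehavior` carries live `KernelFunction` references that System.Text.Json cannot round-trip, so the deserialized copy alone would come back without them. A sketch of the guarantee the fix provides (the null `functions` argument is a stand-in for a real function list):

```csharp
using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Google;

// Stand-in: null advertises all registered functions; a real caller might pass
// functions resolved from kernel.Plugins instead.
IEnumerable<KernelFunction>? functions = null;

PromptExecutionSettings generic = new PromptExecutionSettings
{
    FunctionChoiceBehavior = FunctionChoiceBehavior.Required(functions),
};

var gemini = GeminiPromptExecutionSettings.FromExecutionSettings(generic);

// The original behavior instance (including any function list) is carried over verbatim.
Debug.Assert(ReferenceEquals(gemini.FunctionChoiceBehavior, generic.FunctionChoiceBehavior));
```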