diff --git a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj
index 4d5a3deb9906..9ce032ddcb68 100644
--- a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj
+++ b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj
@@ -12,6 +12,7 @@
+
diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
index 4b948140348f..bdd12c5123c5 100644
--- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
+++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs
@@ -12,6 +12,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.FunctionCalling;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Http;
using Microsoft.SemanticKernel.Text;
@@ -28,30 +29,10 @@ internal sealed class GeminiChatCompletionClient : ClientBase
private readonly string _modelId;
private readonly Uri _chatGenerationEndpoint;
private readonly Uri _chatStreamingEndpoint;
+ private readonly FunctionCallsProcessor _functionCallsProcessor;
private static readonly string s_namespace = typeof(GoogleAIGeminiChatCompletionService).Namespace!;
- /// <summary>
- /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
- /// asynchronous chain of execution.
- /// </summary>
- /// <remarks>
- /// This is a fail-safe mechanism. If someone accidentally manages to set up execution settings in such a way that
- /// auto-invocation is invoked recursively, and in particular where a prompt function is able to auto-invoke itself,
- /// we could end up in an infinite loop. This const is a backstop against that happening. We should never come close
- /// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution.
- /// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in
- /// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that
- /// prompt function could advertise itself as a candidate for auto-invocation. We don't want to outright block that,
- /// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent
- /// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made
- /// configurable should need arise.
- /// </remarks>
- private const int MaxInflightAutoInvokes = 128;
-
- /// <summary>Tracking <see cref="AsyncLocal{Int32}"/> for <see cref="MaxInflightAutoInvokes"/>.</summary>
- private static readonly AsyncLocal<int> s_inflightAutoInvokes = new();
-
/// <summary>
/// Instance of <see cref="Meter"/> for metrics.
/// </summary>
@@ -84,6 +65,12 @@ internal sealed class GeminiChatCompletionClient : ClientBase
unit: "{token}",
description: "Number of tokens used");
+ private sealed record ToolCallingConfig(
+ IList<GeminiTool.FunctionDeclaration>? Tools,
+ GeminiFunctionCallingMode? Mode,
+ bool AutoInvoke,
+ FunctionChoiceBehaviorOptions? Options);
+
/// <summary>
/// Represents a client for interacting with the chat completion Gemini model via GoogleAI.
/// </summary>
@@ -108,6 +95,7 @@ public GeminiChatCompletionClient(
string versionSubLink = GetApiVersionSubLink(apiVersion);
this._modelId = modelId;
+ this._functionCallsProcessor = new FunctionCallsProcessor(this.Logger);
this._chatGenerationEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:generateContent?key={apiKey}");
this._chatStreamingEndpoint = new Uri($"https://generativelanguage.googleapis.com/{versionSubLink}/models/{this._modelId}:streamGenerateContent?key={apiKey}&alt=sse");
}
@@ -142,6 +130,7 @@ public GeminiChatCompletionClient(
string versionSubLink = GetApiVersionSubLink(apiVersion);
this._modelId = modelId;
+ this._functionCallsProcessor = new FunctionCallsProcessor(this.Logger);
this._chatGenerationEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:generateContent");
this._chatStreamingEndpoint = new Uri($"https://{location}-aiplatform.googleapis.com/{versionSubLink}/projects/{projectId}/locations/{location}/publishers/google/models/{this._modelId}:streamGenerateContent?alt=sse");
}
@@ -162,11 +151,16 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
{
var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);
- for (state.Iteration = 1; ; state.Iteration++)
+ for (state.RequestIndex = 0; ; state.RequestIndex++)
{
+ // TODO: do something with this variable
+ var functionCallingConfig = this.GetFunctionCallingConfiguration(state);
+
+ // TODO: the request should be created here, not above the loop
+
List<GeminiChatMessageContent> chatResponses;
using (var activity = ModelDiagnostics.StartCompletionActivity(
- this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
+ this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
{
GeminiResponse geminiResponse;
try
@@ -190,22 +184,38 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
// If we don't want to attempt to invoke any functions, just return the result.
// Or if we are auto-invoking but we somehow end up with other than 1 choice even though only 1 was requested, similarly bail.
- if (!state.AutoInvoke || chatResponses.Count != 1)
+ if (!state.AutoInvoke || chatResponses.Count == 0)
{
return chatResponses;
}
state.LastMessage = chatResponses[0];
+ // TODO: should the ToolCalls property be removed from GeminiChatMessageContent?
if (state.LastMessage.ToolCalls is null)
{
return chatResponses;
}
- // ToolCallBehavior is not null because we are in auto-invoke mode but we check it again to be sure it wasn't changed in the meantime
- Verify.NotNull(state.ExecutionSettings.ToolCallBehavior);
-
- state.AddLastMessageToChatHistoryAndRequest();
- await this.ProcessFunctionsAsync(state, cancellationToken).ConfigureAwait(false);
+ // TODO: remove?
+ // state.AddLastMessageToChatHistoryAndRequest();
+
+ // Process function calls by invoking the functions and adding the results to the chat history.
+ // Each function call will trigger auto-function-invocation filters, which can terminate the process.
+ // In such cases, we'll return the last message in the chat history.
+ var lastMessage = await this._functionCallsProcessor.ProcessFunctionCallsAsync(
+ state.LastMessage,
+ chatHistory,
+ state.RequestIndex,
+ content => IsRequestableTool(state.LastMessage.ToolCalls, content),
+ functionCallingConfig.Options ?? new FunctionChoiceBehaviorOptions(),
+ kernel,
+ isStreaming: false,
+ cancellationToken).ConfigureAwait(false);
+
+ if (lastMessage != null)
+ {
+ return [lastMessage];
+ }
}
}
@@ -225,10 +235,10 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
{
var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);
- for (state.Iteration = 1; ; state.Iteration++)
+ for (state.RequestIndex = 1; ; state.RequestIndex++)
{
using (var activity = ModelDiagnostics.StartCompletionActivity(
- this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
+ this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
{
HttpResponseMessage? httpResponseMessage = null;
Stream? responseStream = null;
@@ -292,6 +302,79 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
}
}
+ private ToolCallingConfig GetFunctionCallingConfiguration(ChatCompletionState state)
+ {
+ // If no function choice behavior is specified, return a default configuration with no tools and no mode
+ if (state.ExecutionSettings.FunctionChoiceBehavior is null)
+ {
+ return new ToolCallingConfig(Tools: null, Mode: null, AutoInvoke: false, Options: null);
+ }
+
+ return this.ConfigureFunctionCalling(state);
+ }
+
+ private ToolCallingConfig ConfigureFunctionCalling(ChatCompletionState state)
+ {
+ var config =
+ this._functionCallsProcessor.GetConfiguration(state.ExecutionSettings.FunctionChoiceBehavior, state.ChatHistory, state.RequestIndex, state.Kernel);
+
+ IList<GeminiTool.FunctionDeclaration>? tools = null;
+ GeminiFunctionCallingMode? toolMode = null;
+ bool autoInvoke = config?.AutoInvoke ?? false;
+
+ if (config?.Functions is { Count: > 0 } functions)
+ {
+ if (config.Choice == FunctionChoice.Auto)
+ {
+ toolMode = GeminiFunctionCallingMode.Default;
+ }
+ else if (config.Choice == FunctionChoice.Required)
+ {
+ toolMode = GeminiFunctionCallingMode.Any;
+ }
+ else if (config.Choice == FunctionChoice.None)
+ {
+ toolMode = GeminiFunctionCallingMode.None;
+ }
+ else
+ {
+ throw new NotSupportedException($"Unsupported function choice '{config.Choice}'.");
+ }
+
+ tools = [];
+
+ foreach (var function in functions)
+ {
+ tools.Add(function.Metadata.ToGeminiFunction().ToFunctionDeclaration());
+ }
+ }
+
+ return new ToolCallingConfig(
+ Tools: tools,
+ Mode: toolMode ?? GeminiFunctionCallingMode.None,
+ AutoInvoke: autoInvoke,
+ Options: config?.Options);
+ }
+
+ /// <summary>Checks if a tool call is for a function that was defined.</summary>
+ private static bool IsRequestableTool(IReadOnlyList<GeminiFunctionToolCall> tools, FunctionCallContent functionCallContent)
+ {
+ foreach (var tool in tools)
+ {
+ if (string.Equals(tool.FullyQualifiedName,
+ FunctionName.ToFullyQualifiedName(
+ functionCallContent.FunctionName,
+ functionCallContent.PluginName,
+ GeminiFunction.NameSeparator),
+ StringComparison.OrdinalIgnoreCase))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
private ChatCompletionState ValidateInputAndCreateChatCompletionState(
ChatHistory chatHistory,
Kernel? kernel,
@@ -391,7 +474,7 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
// Clear the tools. If we end up wanting to use tools, we'll reset it to the desired value.
state.GeminiRequest.Tools = null;
- if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
+ if (state.RequestIndex >= state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
{
// Don't add any tools as we've reached the maximum attempts limit.
if (this.Logger.IsEnabled(LogLevel.Debug))
@@ -408,7 +491,7 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
}
// Disable auto invocation if we've exceeded the allowed limit.
- if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
+ if (state.RequestIndex >= state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
{
state.AutoInvoke = false;
if (this.Logger.IsEnabled(LogLevel.Debug))
@@ -481,11 +564,6 @@ private async Task<GeminiResponse> SendRequestAndReturnValidGeminiResponseAsync(
return geminiResponse;
}
- /// <summary>Checks if a tool call is for a function that was defined.</summary>
- private static bool IsRequestableTool(IEnumerable<GeminiTool.FunctionDeclaration> functions, GeminiFunctionToolCall ftc)
- => functions.Any(geminiFunction =>
- string.Equals(geminiFunction.Name, ftc.FullyQualifiedName, StringComparison.OrdinalIgnoreCase));
-
private void AddToolResponseMessage(
ChatHistory chat,
GeminiRequest request,
@@ -591,8 +669,8 @@ private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
}
private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
- => geminiResponse.Candidates == null ?
- [new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
+ => geminiResponse.Candidates == null
+ ? [new GeminiChatMessageContent(role: AuthorRole.Assistant, content: string.Empty, modelId: this._modelId)]
: geminiResponse.Candidates.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();
private GeminiChatMessageContent GetChatMessageContentFromCandidate(GeminiResponse geminiResponse, GeminiResponseCandidate candidate)
@@ -663,17 +741,17 @@ private static void ValidateAutoInvoke(bool autoInvoke, int resultsPerPrompt)
private static GeminiMetadata GetResponseMetadata(
GeminiResponse geminiResponse,
GeminiResponseCandidate candidate) => new()
- {
- FinishReason = candidate.FinishReason,
- Index = candidate.Index,
- PromptTokenCount = geminiResponse.UsageMetadata?.PromptTokenCount ?? 0,
- CurrentCandidateTokenCount = candidate.TokenCount,
- CandidatesTokenCount = geminiResponse.UsageMetadata?.CandidatesTokenCount ?? 0,
- TotalTokenCount = geminiResponse.UsageMetadata?.TotalTokenCount ?? 0,
- PromptFeedbackBlockReason = geminiResponse.PromptFeedback?.BlockReason,
- PromptFeedbackSafetyRatings = geminiResponse.PromptFeedback?.SafetyRatings.ToList(),
- ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
- };
+ {
+ FinishReason = candidate.FinishReason,
+ Index = candidate.Index,
+ PromptTokenCount = geminiResponse.UsageMetadata?.PromptTokenCount ?? 0,
+ CurrentCandidateTokenCount = candidate.TokenCount,
+ CandidatesTokenCount = geminiResponse.UsageMetadata?.CandidatesTokenCount ?? 0,
+ TotalTokenCount = geminiResponse.UsageMetadata?.TotalTokenCount ?? 0,
+ PromptFeedbackBlockReason = geminiResponse.PromptFeedback?.BlockReason,
+ PromptFeedbackSafetyRatings = geminiResponse.PromptFeedback?.SafetyRatings.ToList(),
+ ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
+ };
private sealed class ChatCompletionState
{
@@ -682,7 +760,7 @@ private sealed class ChatCompletionState
internal Kernel Kernel { get; set; } = null!;
internal GeminiPromptExecutionSettings ExecutionSettings { get; set; } = null!;
internal GeminiChatMessageContent? LastMessage { get; set; }
- internal int Iteration { get; set; }
+ internal int RequestIndex { get; set; }
internal bool AutoInvoke { get; set; }
internal void AddLastMessageToChatHistoryAndRequest()
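
Note (not part of the patch): with GeminiToolCallBehavior removed, tool calling is opted into through the connector-agnostic FunctionChoiceBehavior on the execution settings. A minimal usage sketch under that assumption; the model id, API key placeholder, and TimePlugin are illustrative stand-ins, not part of this diff:

    using Microsoft.SemanticKernel;
    using Microsoft.SemanticKernel.ChatCompletion;
    using Microsoft.SemanticKernel.Connectors.Google;

    var kernel = Kernel.CreateBuilder()
        .AddGoogleAIGeminiChatCompletion(modelId: "gemini-1.5-pro", apiKey: "<your-api-key>")
        .Build();
    kernel.Plugins.AddFromType<TimePlugin>(); // hypothetical plugin exposing [KernelFunction] methods

    var settings = new GeminiPromptExecutionSettings
    {
        // Auto() advertises all kernel functions and enables auto-invocation;
        // ConfigureFunctionCalling maps it to GeminiFunctionCallingMode.Default ("AUTO").
        FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
    };

    var chat = kernel.GetRequiredService<IChatCompletionService>();
    var history = new ChatHistory();
    history.AddUserMessage("What time is it in Warsaw?");
    var reply = await chat.GetChatMessageContentAsync(history, settings, kernel);
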
diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiFunctionCallingMode.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiFunctionCallingMode.cs
new file mode 100644
index 000000000000..b692c9dedd83
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiFunctionCallingMode.cs
@@ -0,0 +1,89 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace Microsoft.SemanticKernel.Connectors.Google.Core;
+
+/// <summary>
+/// Represents a Gemini Function Calling Mode.
+/// </summary>
+[JsonConverter(typeof(GeminiFunctionCallingModeConverter))]
+internal readonly struct GeminiFunctionCallingMode : IEquatable<GeminiFunctionCallingMode>
+{
+ /// <summary>
+ /// The default model behavior. The model decides to predict either a function call or a natural language response.
+ /// </summary>
+ public static GeminiFunctionCallingMode Default { get; } = new("AUTO");
+
+ /// <summary>
+ /// The model is constrained to always predict a function call. If allowed_function_names is not provided,
+ /// the model picks from all of the available function declarations.
+ /// If allowed_function_names is provided, the model picks from the set of allowed functions.
+ /// </summary>
+ public static GeminiFunctionCallingMode Any { get; } = new("ANY");
+
+ /// <summary>
+ /// The model won't predict a function call. In this case, the model behavior is the same as if you don't pass any function declarations.
+ /// </summary>
+ public static GeminiFunctionCallingMode None { get; } = new("NONE");
+
+ /// <summary>
+ /// Gets the label of the property.
+ /// Label is used for serialization.
+ /// </summary>
+ public string Label { get; }
+
+ /// <summary>
+ /// Represents a Gemini Function Calling Mode.
+ /// </summary>
+ [JsonConstructor]
+ public GeminiFunctionCallingMode(string label)
+ {
+ Verify.NotNullOrWhiteSpace(label, nameof(label));
+ this.Label = label;
+ }
+
+ /// <summary>
+ /// Represents the equality operator for comparing two instances of <see cref="GeminiFunctionCallingMode"/>.
+ /// </summary>
+ /// <param name="left">The left instance to compare.</param>
+ /// <param name="right">The right instance to compare.</param>
+ /// <returns><c>true</c> if the two instances are equal; otherwise, <c>false</c>.</returns>
+ public static bool operator ==(GeminiFunctionCallingMode left, GeminiFunctionCallingMode right)
+ => left.Equals(right);
+
+ /// <summary>
+ /// Represents the inequality operator for comparing two instances of <see cref="GeminiFunctionCallingMode"/>.
+ /// </summary>
+ /// <param name="left">The left instance to compare.</param>
+ /// <param name="right">The right instance to compare.</param>
+ /// <returns><c>true</c> if the two instances are not equal; otherwise, <c>false</c>.</returns>
+ public static bool operator !=(GeminiFunctionCallingMode left, GeminiFunctionCallingMode right)
+ => !(left == right);
+
+ /// <inheritdoc />
+ public bool Equals(GeminiFunctionCallingMode other)
+ => string.Equals(this.Label, other.Label, StringComparison.OrdinalIgnoreCase);
+
+ /// <inheritdoc />
+ public override bool Equals(object? obj)
+ => obj is GeminiFunctionCallingMode other && this == other;
+
+ /// <inheritdoc />
+ public override int GetHashCode()
+ => StringComparer.OrdinalIgnoreCase.GetHashCode(this.Label ?? string.Empty);
+
+ /// <inheritdoc />
+ public override string ToString() => this.Label ?? string.Empty;
+}
+
+internal sealed class GeminiFunctionCallingModeConverter : JsonConverter
+{
+ public override GeminiFunctionCallingMode Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ => new(reader.GetString()!);
+
+ public override void Write(Utf8JsonWriter writer, GeminiFunctionCallingMode value, JsonSerializerOptions options)
+ => writer.WriteStringValue(value.Label);
+}
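
Note (not part of the patch): the converter serializes the mode as its bare label, the string value Gemini's functionCallingConfig.mode field expects (AUTO/ANY/NONE). A small round-trip sketch; since the struct is internal, code like this would live alongside the connector's own tests:

    using System.Text.Json;
    using Microsoft.SemanticKernel.Connectors.Google.Core;

    // Serializes as a plain JSON string: "ANY"
    string json = JsonSerializer.Serialize(GeminiFunctionCallingMode.Any);

    // Reading tolerates any casing because Equals compares labels ordinal-ignore-case
    var mode = JsonSerializer.Deserialize<GeminiFunctionCallingMode>("\"any\"");
    bool equal = mode == GeminiFunctionCallingMode.Any; // true
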
diff --git a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs
index 00821e9a2760..37a43ed2a9bb 100644
--- a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs
+++ b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs
@@ -6,7 +6,6 @@
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
-using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Text;
namespace Microsoft.SemanticKernel.Connectors.Google;
@@ -25,7 +24,6 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings
private IList<string>? _stopSequences;
private bool? _audioTimestamp;
private IList<GeminiSafetySetting>? _safetySettings;
- private GeminiToolCallBehavior? _toolCallBehavior;
///
/// Default max tokens for a text generation.
@@ -135,43 +133,6 @@ public IList<GeminiSafetySetting>? SafetySettings
}
}
- /// <summary>
- /// Gets or sets the behavior for how tool calls are handled.
- /// </summary>
- /// <remarks>
- /// <list type="bullet">
- /// <item>To disable all tool calling, set the property to null (the default).</item>
- /// <item>
- /// To allow the model to request one of any number of functions, set the property to an
- /// instance returned from <see cref="GeminiToolCallBehavior.EnableFunctions"/>, called with
- /// a list of the functions available.
- /// </item>
- /// <item>
- /// To allow the model to request one of any of the functions in the supplied <see cref="Kernel"/>,
- /// set the property to <see cref="GeminiToolCallBehavior.EnableKernelFunctions"/> if the client should simply
- /// send the information about the functions and not handle the response in any special manner, or
- /// <see cref="GeminiToolCallBehavior.AutoInvokeKernelFunctions"/> if the client should attempt to automatically
- /// invoke the function and send the result back to the service.
- /// </item>
- /// </list>
- /// For all options where an instance is provided, auto-invoke behavior may be selected. If the service
- /// sends a request for a function call, if auto-invoke has been requested, the client will attempt to
- /// resolve that function from the functions available in the <see cref="Kernel"/>, and if found, rather
- /// than returning the response back to the caller, it will handle the request automatically, invoking
- /// the function, and sending back the result. The intermediate messages will be retained in the
- /// <see cref="ChatHistory"/> if an instance was provided.
- /// </remarks>
- public GeminiToolCallBehavior? ToolCallBehavior
- {
- get => this._toolCallBehavior;
-
- set
- {
- this.ThrowIfFrozen();
- this._toolCallBehavior = value;
- }
- }
-
/// <summary>
/// Indicates if the audio response should include timestamps.
/// If enabled, audio timestamps will be included in the request to the model.
@@ -222,7 +183,6 @@ public override PromptExecutionSettings Clone()
CandidateCount = this.CandidateCount,
StopSequences = this.StopSequences is not null ? new List<string>(this.StopSequences) : null,
SafetySettings = this.SafetySettings?.Select(setting => new GeminiSafetySetting(setting)).ToList(),
- ToolCallBehavior = this.ToolCallBehavior?.Clone(),
AudioTimestamp = this.AudioTimestamp
};
}
@@ -250,6 +210,11 @@ public static GeminiPromptExecutionSettings FromExecutionSettings(PromptExecutio
}
var json = JsonSerializer.Serialize(executionSettings);
- return JsonSerializer.Deserialize<GeminiPromptExecutionSettings>(json, JsonOptionsCache.ReadPermissive)!;
+ var geminiPromptExecutionSettings = JsonSerializer.Deserialize<GeminiPromptExecutionSettings>(json, JsonOptionsCache.ReadPermissive)!;
+
+ // Restore the function choice behavior that lost its internal state (the list of function instances) during the serialization/deserialization round trip.
+ geminiPromptExecutionSettings.FunctionChoiceBehavior = executionSettings.FunctionChoiceBehavior;
+
+ return geminiPromptExecutionSettings;
}
}
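
Note (not part of the patch): a sketch of the invariant the restore step guards. FunctionChoiceBehavior carries live KernelFunction references that a JSON round trip cannot preserve, so FromExecutionSettings reattaches the caller's original instance:

    var source = new PromptExecutionSettings
    {
        FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
    };

    var gemini = GeminiPromptExecutionSettings.FromExecutionSettings(source);

    // The deserialized copy would otherwise hold a behavior stripped of its function list;
    // instead, the caller's instance is reused as-is.
    bool preserved = ReferenceEquals(gemini.FunctionChoiceBehavior, source.FunctionChoiceBehavior); // true
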