diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index 55283d191a84..aa530c8a8c72 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -470,6 +470,44 @@ public void AddChatMessageToRequest() c => Equals(message.Role, c.Role)); } + [Fact] + public void CachedContentFromPromptReturnsAsExpected() + { + // Arrange + var prompt = "prompt-example"; + var executionSettings = new GeminiPromptExecutionSettings + { + CachedContent = "xyz/abc" + }; + + // Act + var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings); + + // Assert + Assert.NotNull(request.Configuration); + Assert.Equal(executionSettings.CachedContent, request.CachedContent); + } + + [Fact] + public void CachedContentFromChatHistoryReturnsAsExpected() + { + // Arrange + ChatHistory chatHistory = []; + chatHistory.AddUserMessage("user-message"); + chatHistory.AddAssistantMessage("assist-message"); + chatHistory.AddUserMessage("user-message2"); + var executionSettings = new GeminiPromptExecutionSettings + { + CachedContent = "xyz/abc" + }; + + // Act + var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings); + + // Assert + Assert.Equal(executionSettings.CachedContent, request.CachedContent); + } + private sealed class DummyContent(object? innerContent, string? modelId = null, IReadOnlyDictionary? 
metadata = null) : KernelContent(innerContent, modelId, metadata); } diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs index 1d9bb5d6377d..0d986d21ca5a 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleAIGeminiChatCompletionServiceTests.cs @@ -1,13 +1,34 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.IO; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Google; using Microsoft.SemanticKernel.Services; using Xunit; namespace SemanticKernel.Connectors.Google.UnitTests.Services; -public sealed class GoogleAIGeminiChatCompletionServiceTests +public sealed class GoogleAIGeminiChatCompletionServiceTests : IDisposable { + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public GoogleAIGeminiChatCompletionServiceTests() + { + this._messageHandlerStub = new() + { + ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StringContent(File.ReadAllText("./TestData/completion_one_response.json")) + } + }; + this._httpClient = new HttpClient(this._messageHandlerStub, false); + } + [Fact] public void AttributesShouldContainModelId() { @@ -18,4 +39,39 @@ public void AttributesShouldContainModelId() // Assert Assert.Equal(model, service.Attributes[AIServiceExtensions.ModelIdKey]); } + + [Theory] + [InlineData(null)] + [InlineData("content")] + [InlineData("")] + public async Task RequestCachedContentWorksCorrectlyAsync(string? 
cachedContent) + { + // Arrange + string model = "fake-model"; + var sut = new GoogleAIGeminiChatCompletionService(model, "key", httpClient: this._httpClient); + + // Act + var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { CachedContent = cachedContent }); + + // Assert + Assert.NotNull(result); + Assert.NotNull(this._messageHandlerStub.RequestContent); + + var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent); + if (cachedContent is not null) + { + Assert.Contains($"\"cachedContent\":\"{cachedContent}\"", requestBody); + } + else + { + // When no cached content is provided, it should not be included in the request body + Assert.DoesNotContain("cachedContent", requestBody); + } + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } } diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs index 89e65fbaa534..0376924c0e91 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/VertexAIGeminiChatCompletionServiceTests.cs @@ -1,14 +1,34 @@ // Copyright (c) Microsoft. All rights reserved. 
+using System; +using System.IO; +using System.Net.Http; +using System.Text; using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Google; using Microsoft.SemanticKernel.Services; using Xunit; namespace SemanticKernel.Connectors.Google.UnitTests.Services; -public sealed class VertexAIGeminiChatCompletionServiceTests +public sealed class VertexAIGeminiChatCompletionServiceTests : IDisposable { + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public VertexAIGeminiChatCompletionServiceTests() + { + this._messageHandlerStub = new() + { + ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StringContent(File.ReadAllText("./TestData/completion_one_response.json")) + } + }; + this._httpClient = new HttpClient(this._messageHandlerStub, false); + } + [Fact] public void AttributesShouldContainModelIdBearerAsString() { @@ -30,4 +50,39 @@ public void AttributesShouldContainModelIdBearerAsFunc() // Assert Assert.Equal(model, service.Attributes[AIServiceExtensions.ModelIdKey]); } + + [Theory] + [InlineData(null)] + [InlineData("content")] + [InlineData("")] + public async Task RequestCachedContentWorksCorrectlyAsync(string? 
cachedContent) + { + // Arrange + string model = "fake-model"; + var sut = new VertexAIGeminiChatCompletionService(model, () => new ValueTask("key"), "location", "project", httpClient: this._httpClient); + + // Act + var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { CachedContent = cachedContent }); + + // Assert + Assert.NotNull(result); + Assert.NotNull(this._messageHandlerStub.RequestContent); + + var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent); + if (cachedContent is not null) + { + Assert.Contains($"\"cachedContent\":\"{cachedContent}\"", requestBody); + } + else + { + // When no cached content is provided, it should not be included in the request body + Assert.DoesNotContain("cachedContent", requestBody); + } + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } } diff --git a/dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs b/dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs index 5d465f5d590f..b94ca9eeebc6 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs @@ -112,6 +112,7 @@ protected static string GetApiVersionSubLink(VertexAIVersion apiVersion) => apiVersion switch { VertexAIVersion.V1 => "v1", + VertexAIVersion.V1_Beta => "v1beta1", _ => throw new NotSupportedException($"Vertex API version {apiVersion} is not supported.") }; } diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index 2ebda2c2a0de..0ff0b5d10bf0 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -42,6 +42,10 @@ internal sealed class GeminiRequest [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public 
GeminiContent? SystemInstruction { get; set; } + [JsonPropertyName("cachedContent")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? CachedContent { get; set; } + public void AddFunction(GeminiFunction function) { // NOTE: Currently Gemini only supports one tool i.e. function calling. @@ -67,6 +71,7 @@ public static GeminiRequest FromPromptAndExecutionSettings( GeminiRequest obj = CreateGeminiRequest(prompt); AddSafetySettings(executionSettings, obj); AddConfiguration(executionSettings, obj); + AddAdditionalBodyFields(executionSettings, obj); return obj; } @@ -83,6 +88,7 @@ public static GeminiRequest FromChatHistoryAndExecutionSettings( GeminiRequest obj = CreateGeminiRequest(chatHistory); AddSafetySettings(executionSettings, obj); AddConfiguration(executionSettings, obj); + AddAdditionalBodyFields(executionSettings, obj); return obj; } @@ -318,6 +324,11 @@ private static void AddSafetySettings(GeminiPromptExecutionSettings executionSet => new GeminiSafetySetting(s.Category, s.Threshold)).ToList(); } + private static void AddAdditionalBodyFields(GeminiPromptExecutionSettings executionSettings, GeminiRequest request) + { + request.CachedContent = executionSettings.CachedContent; + } + internal sealed class ConfigurationElement { [JsonPropertyName("temperature")] diff --git a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs index fab00f01e11d..daa8ea629a5e 100644 --- a/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.Google/GeminiPromptExecutionSettings.cs @@ -27,6 +27,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings private bool? _audioTimestamp; private string? _responseMimeType; private object? _responseSchema; + private string? _cachedContent; private IList? _safetySettings; private GeminiToolCallBehavior? 
_toolCallBehavior; @@ -41,6 +42,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings /// Range is 0.0 to 1.0. /// [JsonPropertyName("temperature")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public double? Temperature { get => this._temperature; @@ -56,6 +58,7 @@ public double? Temperature /// The higher the TopP, the more diverse the completion. /// [JsonPropertyName("top_p")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public double? TopP { get => this._topP; @@ -71,6 +74,7 @@ public double? TopP /// The TopK property represents the maximum value of a collection or dataset. /// [JsonPropertyName("top_k")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public int? TopK { get => this._topK; @@ -85,6 +89,7 @@ public int? TopK /// The maximum number of tokens to generate in the completion. /// [JsonPropertyName("max_tokens")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public int? MaxTokens { get => this._maxTokens; @@ -99,6 +104,7 @@ public int? MaxTokens /// The count of candidates. Possible values range from 1 to 8. /// [JsonPropertyName("candidate_count")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public int? CandidateCount { get => this._candidateCount; @@ -114,6 +120,7 @@ public int? CandidateCount /// Maximum number of stop sequences is 5. /// [JsonPropertyName("stop_sequences")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public IList? StopSequences { get => this._stopSequences; @@ -128,6 +135,7 @@ public IList? StopSequences /// Represents a list of safety settings. /// [JsonPropertyName("safety_settings")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public IList? SafetySettings { get => this._safetySettings; @@ -180,6 +188,7 @@ public GeminiToolCallBehavior? ToolCallBehavior /// if enabled, audio timestamp will be included in the request to the model. 
/// [JsonPropertyName("audio_timestamp")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public bool? AudioTimestamp { get => this._audioTimestamp; @@ -198,6 +207,7 @@ public bool? AudioTimestamp /// 3. text/x.enum: For classification tasks, output an enum value as defined in the response schema. /// [JsonPropertyName("response_mimetype")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public string? ResponseMimeType { get => this._responseMimeType; @@ -234,6 +244,23 @@ public object? ResponseSchema } } + /// + /// Optional. The name of the cached content used as context to serve the prediction. + /// Note: only used in explicit caching, where users can have control over caching (e.g. what content to cache) and enjoy guaranteed cost savings. + /// Format: projects/{project}/locations/{location}/cachedContents/{cachedContent} + /// + [JsonPropertyName("cached_content")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? CachedContent + { + get => this._cachedContent; + set + { + this.ThrowIfFrozen(); + this._cachedContent = value; + } + } + /// public override void Freeze() { diff --git a/dotnet/src/Connectors/Connectors.Google/VertexAIVersion.cs b/dotnet/src/Connectors/Connectors.Google/VertexAIVersion.cs index 8e0a894e9f90..998910d8db42 100644 --- a/dotnet/src/Connectors/Connectors.Google/VertexAIVersion.cs +++ b/dotnet/src/Connectors/Connectors.Google/VertexAIVersion.cs @@ -12,5 +12,10 @@ public enum VertexAIVersion /// /// Represents the V1 version of the Vertex AI API. /// - V1 + V1, + + /// + /// Represents the V1-beta version of the Vertex AI API. 
+ /// + V1_Beta } diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs index 5732a3e4719a..009d8f9bb30b 100644 --- a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs @@ -3,11 +3,15 @@ using System; using System.IO; using System.Linq; +using System.Net.Http; +using System.Net.Http.Json; +using System.Text; using System.Text.Json; using System.Threading.Tasks; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Google; +using Newtonsoft.Json.Linq; using xRetry; using Xunit; using Xunit.Abstractions; @@ -135,6 +139,61 @@ public async Task ChatGenerationWithSystemMessagesAsync(ServiceType serviceType) Assert.Contains("Roger", response.Content, StringComparison.OrdinalIgnoreCase); } + [RetryTheory] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatGenerationWithCachedContentAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Finish this sentence: He knew the sea’s..."); + + // Setup initial cached content + var cachedContentJson = File.ReadAllText(Path.Combine("Resources", "gemini_cached_content.json")) + .Replace("{{project}}", this.VertexAIGetProjectId()) + .Replace("{{location}}", this.VertexAIGetLocation()) + .Replace("{{model}}", this.VertexAIGetGeminiModel()); + + var cachedContentName = string.Empty; + + using (var httpClient = new HttpClient() + { + DefaultRequestHeaders = { Authorization = new("Bearer", this.VertexAIGetBearerKey()) } + }) + { + using (var content = new StringContent(cachedContentJson, Encoding.UTF8, "application/json")) + { + using (var httpResponse = await httpClient.PostAsync( + new 
Uri($"https://{this.VertexAIGetLocation()}-aiplatform.googleapis.com/v1beta1/projects/{this.VertexAIGetProjectId()}/locations/{this.VertexAIGetLocation()}/cachedContents"), + content)) + { + httpResponse.EnsureSuccessStatusCode(); + + var responseString = await httpResponse.Content.ReadAsStringAsync(); + var responseJson = JObject.Parse(responseString); + + cachedContentName = responseJson?["name"]?.ToString(); + + Assert.NotNull(cachedContentName); + } + } + } + + var sut = this.GetChatService(serviceType, isBeta: true); + + // Act + var response = await sut.GetChatMessageContentAsync( + chatHistory, + new GeminiPromptExecutionSettings + { + CachedContent = cachedContentName + }); + + // Assert + Assert.NotNull(response.Content); + this.Output.WriteLine(response.Content); + Assert.Contains("capriciousness", response.Content, StringComparison.OrdinalIgnoreCase); + } + [RetryTheory] [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] diff --git a/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs b/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs index 6b932727f4a6..eb7e42114053 100644 --- a/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. 
using System; +using System.Net.Http; +using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Google; @@ -20,16 +22,18 @@ public abstract class TestsBase(ITestOutputHelper output) protected ITestOutputHelper Output { get; } = output; - protected IChatCompletionService GetChatService(ServiceType serviceType) => serviceType switch + protected IChatCompletionService GetChatService(ServiceType serviceType, bool isBeta = false) => serviceType switch { ServiceType.GoogleAI => new GoogleAIGeminiChatCompletionService( this.GoogleAIGetGeminiModel(), - this.GoogleAIGetApiKey()), + this.GoogleAIGetApiKey(), + isBeta ? GoogleAIVersion.V1_Beta : GoogleAIVersion.V1), ServiceType.VertexAI => new VertexAIGeminiChatCompletionService( modelId: this.VertexAIGetGeminiModel(), bearerKey: this.VertexAIGetBearerKey(), location: this.VertexAIGetLocation(), - projectId: this.VertexAIGetProjectId()), + projectId: this.VertexAIGetProjectId(), + isBeta ? 
VertexAIVersion.V1_Beta : VertexAIVersion.V1), _ => throw new ArgumentOutOfRangeException(nameof(serviceType), serviceType, null) }; @@ -69,10 +73,10 @@ public enum ServiceType private string GoogleAIGetGeminiVisionModel() => this._configuration.GetSection("GoogleAI:Gemini:VisionModelId").Get()!; private string GoogleAIGetEmbeddingModel() => this._configuration.GetSection("GoogleAI:EmbeddingModelId").Get()!; private string GoogleAIGetApiKey() => this._configuration.GetSection("GoogleAI:ApiKey").Get()!; - private string VertexAIGetGeminiModel() => this._configuration.GetSection("VertexAI:Gemini:ModelId").Get()!; + internal string VertexAIGetGeminiModel() => this._configuration.GetSection("VertexAI:Gemini:ModelId").Get()!; private string VertexAIGetGeminiVisionModel() => this._configuration.GetSection("VertexAI:Gemini:VisionModelId").Get()!; private string VertexAIGetEmbeddingModel() => this._configuration.GetSection("VertexAI:EmbeddingModelId").Get()!; - private string VertexAIGetBearerKey() => this._configuration.GetSection("VertexAI:BearerKey").Get()!; - private string VertexAIGetLocation() => this._configuration.GetSection("VertexAI:Location").Get()!; - private string VertexAIGetProjectId() => this._configuration.GetSection("VertexAI:ProjectId").Get()!; + internal string VertexAIGetBearerKey() => this._configuration.GetSection("VertexAI:BearerKey").Get()!; + internal string VertexAIGetLocation() => this._configuration.GetSection("VertexAI:Location").Get()!; + internal string VertexAIGetProjectId() => this._configuration.GetSection("VertexAI:ProjectId").Get()!; } diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index e24215b583d6..cd4f12741f96 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -195,4 +195,10 @@ Always + + + + Always + + \ No newline at end of file diff --git 
a/dotnet/src/IntegrationTests/Resources/gemini_cached_content.json b/dotnet/src/IntegrationTests/Resources/gemini_cached_content.json new file mode 100644 index 000000000000..fa5e4f688efc --- /dev/null +++ b/dotnet/src/IntegrationTests/Resources/gemini_cached_content.json @@ -0,0 +1,22 @@ +{ + "model": "projects/{{project}}/locations/{{location}}/publishers/google/models/{{model}}", + "displayName": "CACHE_DISPLAY_NAME", + "contents": [ + { + "role": "model", + "parts": [ + { + "text": "This is sample text to demonstrate explicit caching." + } + ] + }, + { + "role": "user", + "parts": [ + { + "text": "The old lighthouse keeper, Silas, squinted at the churning grey sea, his weathered face mirroring the granite rocks below. He’d seen countless storms, each one a furious dance of wind and wave, but tonight felt different, a simmering unease prickling his skin. The lantern, his steadfast companion, pulsed its rhythmic beam, a fragile defiance against the encroaching darkness. A small boat, barely visible through the swirling mist, was bucking against the tide, its lone mast a broken finger pointing towards the sky. Silas grabbed his oilskins, his movements stiff with age, and descended the winding stairs, his heart thumping a frantic rhythm against his ribs. He knew the sea’s capriciousness, its ability to lull and then lash out with brutal force." + } + ] + } + ] +} \ No newline at end of file