diff --git a/LangChain.sln b/LangChain.sln index ddcf6194..37c12bf8 100644 --- a/LangChain.sln +++ b/LangChain.sln @@ -360,6 +360,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Groq", "Groq", "{5DEC2707-D EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Groq.Tests", "src\Providers\Groq\test\LangChain.Providers.Groq.Tests.csproj", "{CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "VertexAI", "VertexAI", "{1F6C3FD4-959B-4A6E-ACCA-C0466149EC05}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.VertexAI", "src\Providers\Google.VertexAI\src\LangChain.Providers.Google.VertexAI.csproj", "{F5B1B04A-F72B-44D2-8B25-BB3717146D00}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.VertexAI.Tests", "src\Providers\Google.VertexAI\test\LangChain.Providers.Google.VertexAI.Tests.csproj", "{EFD4C813-47A1-4B40-8A23-D660B82B8938}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -734,6 +740,14 @@ Global {CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}.Debug|Any CPU.Build.0 = Debug|Any CPU {CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}.Release|Any CPU.ActiveCfg = Release|Any CPU {CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}.Release|Any CPU.Build.0 = Release|Any CPU + {F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Release|Any CPU.Build.0 = Release|Any CPU + {EFD4C813-47A1-4B40-8A23-D660B82B8938}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EFD4C813-47A1-4B40-8A23-D660B82B8938}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EFD4C813-47A1-4B40-8A23-D660B82B8938}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EFD4C813-47A1-4B40-8A23-D660B82B8938}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -878,6 +892,9 @@ Global {FD0A56AD-AFB4-4A21-99C1-9BE5D074EB56} = {5DEC2707-DD62-4DCD-AC5D-6670AC2A1B01} {5DEC2707-DD62-4DCD-AC5D-6670AC2A1B01} = {E2B9833C-0397-4FAF-A3A8-116E58749750} {CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E} = {5DEC2707-DD62-4DCD-AC5D-6670AC2A1B01} + {1F6C3FD4-959B-4A6E-ACCA-C0466149EC05} = {A23BD019-BE70-42D5-9DD4-A79DEDDE54F4} + {F5B1B04A-F72B-44D2-8B25-BB3717146D00} = {1F6C3FD4-959B-4A6E-ACCA-C0466149EC05} + {EFD4C813-47A1-4B40-8A23-D660B82B8938} = {1F6C3FD4-959B-4A6E-ACCA-C0466149EC05} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C} diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index b3f2b967..ec5a67ac 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -19,6 +19,7 @@ + diff --git a/src/Providers/Google.VertexAI/src/LangChain.Providers.Google.VertexAI.csproj b/src/Providers/Google.VertexAI/src/LangChain.Providers.Google.VertexAI.csproj new file mode 100644 index 00000000..91f4bd1a --- /dev/null +++ b/src/Providers/Google.VertexAI/src/LangChain.Providers.Google.VertexAI.csproj @@ -0,0 +1,16 @@ + + + + netstandard2.0 + $(NoWarn);CS3003 + + + + + + + + + + + diff --git a/src/Providers/Google.VertexAI/src/Predefined/GeminiModels.cs b/src/Providers/Google.VertexAI/src/Predefined/GeminiModels.cs new file mode 100644 index 00000000..0a0f8dad --- /dev/null +++ b/src/Providers/Google.VertexAI/src/Predefined/GeminiModels.cs @@ -0,0 +1,11 @@ +namespace LangChain.Providers.Google.VertexAI.Predefined +{ + public class Gemini15ProModel(VertexAIProvider provider) + : VertexAIChatModel(provider, "gemini-1.5-pro"); + + public class Gemini15FlashModel(VertexAIProvider provider) + : VertexAIChatModel(provider, "gemini-1.5-flash"); + + public class Gemini1ProModel(VertexAIProvider provider) + : VertexAIChatModel(provider, "gemini-1.0-pro"); +} diff --git a/src/Providers/Google.VertexAI/src/VertexAIChatModel.cs b/src/Providers/Google.VertexAI/src/VertexAIChatModel.cs new file mode 100644 index 00000000..97dbe70d --- /dev/null +++ b/src/Providers/Google.VertexAI/src/VertexAIChatModel.cs @@ -0,0 +1,79 @@ +using Google.Cloud.AIPlatform.V1; +using Google.Protobuf.Collections; +using System.Diagnostics; + +namespace LangChain.Providers.Google.VertexAI +{ + public class VertexAIChatModel( + VertexAIProvider provider, + string id + ) : ChatModel(id), IChatModel + { + private VertexAIProvider Provider { get; } = provider ?? throw new ArgumentNullException(nameof(provider)); + public override async Task GenerateAsync(ChatRequest request, + ChatSettings? settings = null, + CancellationToken cancellationToken = default) + { + + request = request ?? throw new ArgumentNullException(nameof(request)); + var prompt = ToPrompt(request.Messages); + + var watch = Stopwatch.StartNew(); + var response = await Provider.Api.GenerateContentAsync(prompt).ConfigureAwait(false); + + var result = request.Messages.ToList(); + result.Add(response.Candidates[0].Content.Parts[0].Text.AsAiMessage()); + + var usage = Usage.Empty with + { + Time = watch.Elapsed, + InputTokens = response.UsageMetadata.PromptTokenCount, + OutputTokens = response.UsageMetadata.CandidatesTokenCount + }; + + return new ChatResponse + { + Messages = result, + Usage = usage, + UsedSettings = ChatSettings.Default, + }; + + } + + private GenerateContentRequest ToPrompt(IEnumerable messages) + { + var contents = new RepeatedField(); + foreach (var message in messages) + { + contents.Add(ConvertMessage(message)); + } + + return new GenerateContentRequest + { + Model = $"projects/{provider.Configuration.ProjectId}/locations/{provider.Configuration.Location}/publishers/{provider.Configuration.Publisher}/models/{Id}", + Contents = { contents }, + GenerationConfig = provider.Configuration.GenerationConfig + }; + } + + private static Content ConvertMessage(Message message) + { + return new Content + { + Role = ConvertRole(message.Role), + Parts = { new Part { Text = message.Content } } + }; + } + + private static string ConvertRole(MessageRole role) + { + return role switch + { + MessageRole.Human => "USER", + MessageRole.Ai => "MODEL", + MessageRole.System => "SYSTEM", + _ => throw new NotSupportedException($"the role {role} is not supported") + }; + } + } +} diff --git a/src/Providers/Google.VertexAI/src/VertexAIConfiguration.cs b/src/Providers/Google.VertexAI/src/VertexAIConfiguration.cs new file mode 100644 index 00000000..ed4fd2e9 --- /dev/null +++ b/src/Providers/Google.VertexAI/src/VertexAIConfiguration.cs @@ -0,0 +1,13 @@ +using Google.Cloud.AIPlatform.V1; + +namespace LangChain.Providers.Google.VertexAI +{ + public class VertexAIConfiguration + { + public const string SectionName = "VertexAI"; + public string Location { get; set; } = "us-central1"; + public string Publisher { get; set; } = "google"; + public required string ProjectId { get; set; } + public GenerationConfig? GenerationConfig { get; set; } + } +} \ No newline at end of file diff --git a/src/Providers/Google.VertexAI/src/VertexAIProvider.cs b/src/Providers/Google.VertexAI/src/VertexAIProvider.cs new file mode 100644 index 00000000..63ea7bf0 --- /dev/null +++ b/src/Providers/Google.VertexAI/src/VertexAIProvider.cs @@ -0,0 +1,19 @@ +using Google.Cloud.AIPlatform.V1; + +namespace LangChain.Providers.Google.VertexAI +{ + + public class VertexAIProvider : Provider + { + public PredictionServiceClient Api { get; private set; } + public VertexAIConfiguration Configuration { get; private set; } + public VertexAIProvider(VertexAIConfiguration configuration) : base(id: VertexAIConfiguration.SectionName) + { + Configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + Api = new PredictionServiceClientBuilder + { + Endpoint = $"{Configuration.Location}-aiplatform.googleapis.com" + }.Build(); + } + } +} diff --git a/src/Providers/Google.VertexAI/test/LangChain.Providers.Google.VertexAI.Tests.csproj b/src/Providers/Google.VertexAI/test/LangChain.Providers.Google.VertexAI.Tests.csproj new file mode 100644 index 00000000..b5105c90 --- /dev/null +++ b/src/Providers/Google.VertexAI/test/LangChain.Providers.Google.VertexAI.Tests.csproj @@ -0,0 +1,11 @@ + + + + net8.0 + + + + + + + diff --git a/src/Providers/Google.VertexAI/test/VertexAITest.cs b/src/Providers/Google.VertexAI/test/VertexAITest.cs new file mode 100644 index 00000000..0ad51907 --- /dev/null +++ b/src/Providers/Google.VertexAI/test/VertexAITest.cs @@ -0,0 +1,47 @@ +using Google.Apis.Auth.OAuth2; +using LangChain.Providers.Google.VertexAI.Predefined; + +namespace LangChain.Providers.Google.VertexAI.Test +{ + [TestFixture] + [Explicit] + public partial class VertexAITests + { + [Test] + public async Task Chat() + { + + //Required 'GOOGLE_APPLICATION_CREDENTIALS' env with Google credentials path json file. + + var credentials = GoogleCredential.GetApplicationDefault(); + + if (credentials.UnderlyingCredential is ServiceAccountCredential serviceAccountCredential) + { + + var config = new VertexAIConfiguration() + { + ProjectId = serviceAccountCredential.ProjectId, + //Publisher = "google", + //Location = "us-central1", + /*GenerationConfig = new GenerationConfig + { + Temperature = 0.4f, + TopP = 1, + TopK = 32, + MaxOutputTokens = 2048 + }*/ + }; + + var provider = new VertexAIProvider(config); + var model = new Gemini15ProModel(provider); + + string answer = await model.GenerateAsync("Generate some random name:"); + + answer.Should().NotBeNull(); + + Console.WriteLine(answer); + } + + } + } +} \ No newline at end of file