Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Google VertexAI provider chat #377

Merged
merged 1 commit into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions LangChain.sln
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Groq", "Groq", "{5DEC2707-D
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Groq.Tests", "src\Providers\Groq\test\LangChain.Providers.Groq.Tests.csproj", "{CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}"
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "VertexAI", "VertexAI", "{1F6C3FD4-959B-4A6E-ACCA-C0466149EC05}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.VertexAI", "src\Providers\Google.VertexAI\src\LangChain.Providers.Google.VertexAI.csproj", "{F5B1B04A-F72B-44D2-8B25-BB3717146D00}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LangChain.Providers.Google.VertexAI.Tests", "src\Providers\Google.VertexAI\test\LangChain.Providers.Google.VertexAI.Tests.csproj", "{EFD4C813-47A1-4B40-8A23-D660B82B8938}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -734,6 +740,14 @@ Global
{CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E}.Release|Any CPU.Build.0 = Release|Any CPU
{F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F5B1B04A-F72B-44D2-8B25-BB3717146D00}.Release|Any CPU.Build.0 = Release|Any CPU
{EFD4C813-47A1-4B40-8A23-D660B82B8938}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EFD4C813-47A1-4B40-8A23-D660B82B8938}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EFD4C813-47A1-4B40-8A23-D660B82B8938}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EFD4C813-47A1-4B40-8A23-D660B82B8938}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -878,6 +892,9 @@ Global
{FD0A56AD-AFB4-4A21-99C1-9BE5D074EB56} = {5DEC2707-DD62-4DCD-AC5D-6670AC2A1B01}
{5DEC2707-DD62-4DCD-AC5D-6670AC2A1B01} = {E2B9833C-0397-4FAF-A3A8-116E58749750}
{CC7F58F4-C824-4BED-8C4A-760C9AB8FC6E} = {5DEC2707-DD62-4DCD-AC5D-6670AC2A1B01}
{1F6C3FD4-959B-4A6E-ACCA-C0466149EC05} = {A23BD019-BE70-42D5-9DD4-A79DEDDE54F4}
{F5B1B04A-F72B-44D2-8B25-BB3717146D00} = {1F6C3FD4-959B-4A6E-ACCA-C0466149EC05}
{EFD4C813-47A1-4B40-8A23-D660B82B8938} = {1F6C3FD4-959B-4A6E-ACCA-C0466149EC05}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {5C00D0F1-6138-4ED9-846B-97E43D6DFF1C}
Expand Down
1 change: 1 addition & 0 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<PackageVersion Include="DotNet.ReproducibleBuilds" Version="1.2.4" />
<PackageVersion Include="FluentAssertions" Version="6.12.0" />
<PackageVersion Include="GitHubActionsTestLogger" Version="2.4.1" />
<PackageVersion Include="Google.Cloud.AIPlatform.V1" Version="3.4.0" />
<PackageVersion Include="Google_GenerativeAI" Version="1.0.1" />
<PackageVersion Include="GroqSharp" Version="1.1.2" />
<PackageVersion Include="H.Generators.Extensions" Version="1.22.0" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<NoWarn>$(NoWarn);CS3003</NoWarn>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Google.Cloud.AIPlatform.V1" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Abstractions\src\LangChain.Providers.Abstractions.csproj" />
</ItemGroup>

</Project>
11 changes: 11 additions & 0 deletions src/Providers/Google.VertexAI/src/Predefined/GeminiModels.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace LangChain.Providers.Google.VertexAI.Predefined
{
public class Gemini15ProModel(VertexAIProvider provider)
: VertexAIChatModel(provider, "gemini-1.5-pro");

public class Gemini15FlashModel(VertexAIProvider provider)
: VertexAIChatModel(provider, "gemini-1.5-flash");

public class Gemini1ProModel(VertexAIProvider provider)
: VertexAIChatModel(provider, "gemini-1.0-pro");
}
79 changes: 79 additions & 0 deletions src/Providers/Google.VertexAI/src/VertexAIChatModel.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Google.Cloud.AIPlatform.V1;
using Google.Protobuf.Collections;
using System.Diagnostics;

namespace LangChain.Providers.Google.VertexAI
{
public class VertexAIChatModel(
VertexAIProvider provider,
string id
) : ChatModel(id), IChatModel
{
private VertexAIProvider Provider { get; } = provider ?? throw new ArgumentNullException(nameof(provider));
public override async Task<ChatResponse> GenerateAsync(ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{

request = request ?? throw new ArgumentNullException(nameof(request));
var prompt = ToPrompt(request.Messages);

var watch = Stopwatch.StartNew();
var response = await Provider.Api.GenerateContentAsync(prompt).ConfigureAwait(false);

var result = request.Messages.ToList();
result.Add(response.Candidates[0].Content.Parts[0].Text.AsAiMessage());

var usage = Usage.Empty with
{
Time = watch.Elapsed,
InputTokens = response.UsageMetadata.PromptTokenCount,
OutputTokens = response.UsageMetadata.CandidatesTokenCount
};

return new ChatResponse
{
Messages = result,
Usage = usage,
UsedSettings = ChatSettings.Default,
};

}
Comment on lines +13 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use provided settings parameter instead of hardcoding ChatSettings.Default.

The UsedSettings property of the ChatResponse object is hardcoded to ChatSettings.Default. It should use the settings parameter if provided.

-  UsedSettings = ChatSettings.Default,
+  UsedSettings = settings ?? ChatSettings.Default,
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
public override async Task<ChatResponse> GenerateAsync(ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
var prompt = ToPrompt(request.Messages);
var watch = Stopwatch.StartNew();
var response = await Provider.Api.GenerateContentAsync(prompt).ConfigureAwait(false);
var result = request.Messages.ToList();
result.Add(response.Candidates[0].Content.Parts[0].Text.AsAiMessage());
var usage = Usage.Empty with
{
Time = watch.Elapsed,
InputTokens = response.UsageMetadata.PromptTokenCount,
OutputTokens = response.UsageMetadata.CandidatesTokenCount
};
return new ChatResponse
{
Messages = result,
Usage = usage,
UsedSettings = ChatSettings.Default,
};
}
public override async Task<ChatResponse> GenerateAsync(ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
var prompt = ToPrompt(request.Messages);
var watch = Stopwatch.StartNew();
var response = await Provider.Api.GenerateContentAsync(prompt).ConfigureAwait(false);
var result = request.Messages.ToList();
result.Add(response.Candidates[0].Content.Parts[0].Text.AsAiMessage());
var usage = Usage.Empty with
{
Time = watch.Elapsed,
InputTokens = response.UsageMetadata.PromptTokenCount,
OutputTokens = response.UsageMetadata.CandidatesTokenCount
};
return new ChatResponse
{
Messages = result,
Usage = usage,
UsedSettings = settings ?? ChatSettings.Default,
};
}

Remove unnecessary re-assignment of request.

The request parameter is re-assigned after the null-check, which is unnecessary.

-  request = request ?? throw new ArgumentNullException(nameof(request));
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
public override async Task<ChatResponse> GenerateAsync(ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
var prompt = ToPrompt(request.Messages);
var watch = Stopwatch.StartNew();
var response = await Provider.Api.GenerateContentAsync(prompt).ConfigureAwait(false);
var result = request.Messages.ToList();
result.Add(response.Candidates[0].Content.Parts[0].Text.AsAiMessage());
var usage = Usage.Empty with
{
Time = watch.Elapsed,
InputTokens = response.UsageMetadata.PromptTokenCount,
OutputTokens = response.UsageMetadata.CandidatesTokenCount
};
return new ChatResponse
{
Messages = result,
Usage = usage,
UsedSettings = ChatSettings.Default,
};
}
public override async Task<ChatResponse> GenerateAsync(ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));
var prompt = ToPrompt(request.Messages);
var watch = Stopwatch.StartNew();
var response = await Provider.Api.GenerateContentAsync(prompt).ConfigureAwait(false);
var result = request.Messages.ToList();
result.Add(response.Candidates[0].Content.Parts[0].Text.AsAiMessage());
var usage = Usage.Empty with
{
Time = watch.Elapsed,
InputTokens = response.UsageMetadata.PromptTokenCount,
OutputTokens = response.UsageMetadata.CandidatesTokenCount
};
return new ChatResponse
{
Messages = result,
Usage = usage,
UsedSettings = ChatSettings.Default,
};
}


private GenerateContentRequest ToPrompt(IEnumerable<Message> messages)
{
var contents = new RepeatedField<Content>();
foreach (var message in messages)
{
contents.Add(ConvertMessage(message));
}

return new GenerateContentRequest
{
Model = $"projects/{provider.Configuration.ProjectId}/locations/{provider.Configuration.Location}/publishers/{provider.Configuration.Publisher}/models/{Id}",
Contents = { contents },
GenerationConfig = provider.Configuration.GenerationConfig
};
}
Comment on lines +43 to +57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify contents initialization using LINQ.

The contents variable declaration and initialization can be simplified using LINQ.

-  var contents = new RepeatedField<Content>();
-  foreach (var message in messages)
-  {
-    contents.Add(ConvertMessage(message));
-  }
+  var contents = new RepeatedField<Content>(messages.Select(ConvertMessage));

Committable suggestion was skipped due to low confidence.


private static Content ConvertMessage(Message message)
{
return new Content
{
Role = ConvertRole(message.Role),
Parts = { new Part { Text = message.Content } }
};
}

private static string ConvertRole(MessageRole role)
{
return role switch
{
MessageRole.Human => "USER",
MessageRole.Ai => "MODEL",
MessageRole.System => "SYSTEM",
_ => throw new NotSupportedException($"the role {role} is not supported")
};
}
}
}
13 changes: 13 additions & 0 deletions src/Providers/Google.VertexAI/src/VertexAIConfiguration.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using Google.Cloud.AIPlatform.V1;

namespace LangChain.Providers.Google.VertexAI
{
public class VertexAIConfiguration
{
public const string SectionName = "VertexAI";
public string Location { get; set; } = "us-central1";
public string Publisher { get; set; } = "google";
public required string ProjectId { get; set; }
public GenerationConfig? GenerationConfig { get; set; }
}
}
19 changes: 19 additions & 0 deletions src/Providers/Google.VertexAI/src/VertexAIProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using Google.Cloud.AIPlatform.V1;

namespace LangChain.Providers.Google.VertexAI
{

public class VertexAIProvider : Provider
{
public PredictionServiceClient Api { get; private set; }
public VertexAIConfiguration Configuration { get; private set; }
public VertexAIProvider(VertexAIConfiguration configuration) : base(id: VertexAIConfiguration.SectionName)
{
Configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
Api = new PredictionServiceClientBuilder
{
Endpoint = $"{Configuration.Location}-aiplatform.googleapis.com"
}.Build();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\src\LangChain.Providers.Google.VertexAI.csproj"/>
</ItemGroup>

</Project>
47 changes: 47 additions & 0 deletions src/Providers/Google.VertexAI/test/VertexAITest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using Google.Apis.Auth.OAuth2;
using LangChain.Providers.Google.VertexAI.Predefined;

namespace LangChain.Providers.Google.VertexAI.Test
{
[TestFixture]
[Explicit]
public partial class VertexAITests
{
[Test]
public async Task Chat()
{

//Required 'GOOGLE_APPLICATION_CREDENTIALS' env with Google credentials path json file.

var credentials = GoogleCredential.GetApplicationDefault();

if (credentials.UnderlyingCredential is ServiceAccountCredential serviceAccountCredential)
{

var config = new VertexAIConfiguration()
{
ProjectId = serviceAccountCredential.ProjectId,
//Publisher = "google",
//Location = "us-central1",
/*GenerationConfig = new GenerationConfig
{
Temperature = 0.4f,
TopP = 1,
TopK = 32,
MaxOutputTokens = 2048
}*/
};

var provider = new VertexAIProvider(config);
var model = new Gemini15ProModel(provider);

string answer = await model.GenerateAsync("Generate some random name:");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding error handling.

Add error handling to manage potential exceptions from the GenerateAsync method.

-  string answer = await model.GenerateAsync("Generate some random name:");
+  string answer;
+  try {
+    answer = await model.GenerateAsync("Generate some random name:");
+  } catch (Exception ex) {
+    Console.WriteLine($"Error generating response: {ex.Message}");
+    throw;
+  }
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
string answer = await model.GenerateAsync("Generate some random name:");
string answer;
try {
answer = await model.GenerateAsync("Generate some random name:");
} catch (Exception ex) {
Console.WriteLine($"Error generating response: {ex.Message}");
throw;
}


answer.Should().NotBeNull();

Console.WriteLine(answer);
}

}
}
}
Loading