-
-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add Google VertexAI provider chat #377
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>netstandard2.0</TargetFramework>
    <!-- CS3003: suppress CLS-compliance warnings raised by Google protobuf types. -->
    <NoWarn>$(NoWarn);CS3003</NoWarn>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="Google.Cloud.AIPlatform.V1" />
  </ItemGroup>

  <ItemGroup>
    <ProjectReference Include="..\..\Abstractions\src\LangChain.Providers.Abstractions.csproj" />
  </ItemGroup>

</Project>
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
namespace LangChain.Providers.Google.VertexAI.Predefined
{
    /// <summary>Chat model bound to the Vertex AI "gemini-1.5-pro" model id.</summary>
    public class Gemini15ProModel(VertexAIProvider provider)
        : VertexAIChatModel(provider, "gemini-1.5-pro");

    /// <summary>Chat model bound to the Vertex AI "gemini-1.5-flash" model id.</summary>
    public class Gemini15FlashModel(VertexAIProvider provider)
        : VertexAIChatModel(provider, "gemini-1.5-flash");

    /// <summary>Chat model bound to the Vertex AI "gemini-1.0-pro" model id.</summary>
    public class Gemini1ProModel(VertexAIProvider provider)
        : VertexAIChatModel(provider, "gemini-1.0-pro");
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
using Google.Cloud.AIPlatform.V1;
using Google.Protobuf.Collections;
using System.Diagnostics;

namespace LangChain.Providers.Google.VertexAI
{
    /// <summary>
    /// Chat model that generates completions through the Google Vertex AI
    /// <see cref="PredictionServiceClient"/> supplied by <see cref="VertexAIProvider"/>.
    /// </summary>
    /// <param name="provider">Provider holding the API client and configuration; must not be null.</param>
    /// <param name="id">Vertex AI model id, e.g. "gemini-1.5-pro".</param>
    public class VertexAIChatModel(
        VertexAIProvider provider,
        string id
    ) : ChatModel(id), IChatModel
    {
        private VertexAIProvider Provider { get; } = provider ?? throw new ArgumentNullException(nameof(provider));

        /// <summary>
        /// Sends <paramref name="request"/> to Vertex AI and returns the original
        /// messages followed by the model's reply, plus time/token usage.
        /// </summary>
        /// <param name="request">Conversation to send. Must not be null.</param>
        /// <param name="settings">Optional settings; reported back via <c>UsedSettings</c>.</param>
        /// <param name="cancellationToken">Cancels the underlying API call.</param>
        /// <returns>The extended conversation with usage metadata.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
        public override async Task<ChatResponse> GenerateAsync(ChatRequest request,
            ChatSettings? settings = null,
            CancellationToken cancellationToken = default)
        {
            if (request is null)
            {
                throw new ArgumentNullException(nameof(request));
            }

            var prompt = ToPrompt(request.Messages);

            var watch = Stopwatch.StartNew();
            // Flow the caller's cancellation token into the gRPC call (was previously ignored).
            var response = await Provider.Api.GenerateContentAsync(prompt, cancellationToken).ConfigureAwait(false);

            var result = request.Messages.ToList();
            result.Add(response.Candidates[0].Content.Parts[0].Text.AsAiMessage());

            var usage = Usage.Empty with
            {
                Time = watch.Elapsed,
                InputTokens = response.UsageMetadata.PromptTokenCount,
                OutputTokens = response.UsageMetadata.CandidatesTokenCount
            };

            return new ChatResponse
            {
                Messages = result,
                Usage = usage,
                // Report the settings the caller actually passed, not always the defaults.
                UsedSettings = settings ?? ChatSettings.Default,
            };
        }

        /// <summary>Builds the Vertex AI request for the given conversation.</summary>
        private GenerateContentRequest ToPrompt(IEnumerable<Message> messages)
        {
            return new GenerateContentRequest
            {
                // Fully-qualified model resource name required by the Vertex AI API.
                Model = $"projects/{Provider.Configuration.ProjectId}/locations/{Provider.Configuration.Location}/publishers/{Provider.Configuration.Publisher}/models/{Id}",
                // RepeatedField<T>.Add accepts an IEnumerable<T>, so the collection
                // initializer replaces the manual foreach loop.
                Contents = { messages.Select(ConvertMessage) },
                GenerationConfig = Provider.Configuration.GenerationConfig
            };
        }

        /// <summary>Converts a provider-agnostic message into a Vertex AI content entry.</summary>
        private static Content ConvertMessage(Message message)
        {
            return new Content
            {
                Role = ConvertRole(message.Role),
                Parts = { new Part { Text = message.Content } }
            };
        }

        /// <summary>Maps the abstract message role onto the Vertex AI role string.</summary>
        private static string ConvertRole(MessageRole role)
        {
            return role switch
            {
                MessageRole.Human => "USER",
                MessageRole.Ai => "MODEL",
                MessageRole.System => "SYSTEM",
                _ => throw new NotSupportedException($"the role {role} is not supported")
            };
        }
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
using Google.Cloud.AIPlatform.V1;

namespace LangChain.Providers.Google.VertexAI
{
    /// <summary>
    /// Settings used to reach Google Vertex AI and build model resource names.
    /// </summary>
    public class VertexAIConfiguration
    {
        /// <summary>Configuration section name used when binding from app settings.</summary>
        public const string SectionName = "VertexAI";

        /// <summary>Google Cloud region hosting the models. Defaults to "us-central1".</summary>
        public string Location { get; set; } = "us-central1";

        /// <summary>Publisher segment of the model resource name. Defaults to "google".</summary>
        public string Publisher { get; set; } = "google";

        /// <summary>Google Cloud project id. Required.</summary>
        public required string ProjectId { get; set; }

        /// <summary>Optional generation parameters applied to every request; null means API defaults.</summary>
        public GenerationConfig? GenerationConfig { get; set; }
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
using Google.Cloud.AIPlatform.V1;

namespace LangChain.Providers.Google.VertexAI
{
    /// <summary>
    /// Provider that wires a <see cref="PredictionServiceClient"/> to the
    /// region-specific Vertex AI endpoint from <see cref="VertexAIConfiguration"/>.
    /// </summary>
    public class VertexAIProvider : Provider
    {
        /// <summary>Underlying Vertex AI prediction client.</summary>
        public PredictionServiceClient Api { get; private set; }

        /// <summary>Configuration this provider was created with.</summary>
        public VertexAIConfiguration Configuration { get; private set; }

        /// <summary>
        /// Creates the provider and builds a client against the regional endpoint.
        /// </summary>
        /// <param name="configuration">Connection settings; must not be null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
        public VertexAIProvider(VertexAIConfiguration configuration)
            : base(id: VertexAIConfiguration.SectionName)
        {
            Configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));

            // Vertex AI endpoints are regional, e.g. "us-central1-aiplatform.googleapis.com".
            var builder = new PredictionServiceClientBuilder
            {
                Endpoint = $"{Configuration.Location}-aiplatform.googleapis.com"
            };
            Api = builder.Build();
        }
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net8.0</TargetFramework>
  </PropertyGroup>

  <ItemGroup>
    <ProjectReference Include="..\src\LangChain.Providers.Google.VertexAI.csproj" />
  </ItemGroup>

</Project>
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,47 @@ | ||||||||||||||||||
using Google.Apis.Auth.OAuth2;
using LangChain.Providers.Google.VertexAI.Predefined;

namespace LangChain.Providers.Google.VertexAI.Test
{
    [TestFixture]
    [Explicit]
    public partial class VertexAITests
    {
        /// <summary>
        /// Smoke test: sends a single prompt to Gemini 1.5 Pro through Vertex AI.
        /// Requires the 'GOOGLE_APPLICATION_CREDENTIALS' environment variable to
        /// point at a service-account JSON key file.
        /// </summary>
        [Test]
        public async Task Chat()
        {
            var credentials = GoogleCredential.GetApplicationDefault();

            // Guard clause: previously the test silently passed when the default
            // credentials were not a service account; report that state explicitly.
            if (credentials.UnderlyingCredential is not ServiceAccountCredential serviceAccountCredential)
            {
                Assert.Inconclusive(
                    "Application default credentials are not a service account; cannot resolve ProjectId.");
                return;
            }

            var config = new VertexAIConfiguration
            {
                // ProjectId comes from the service-account key file.
                ProjectId = serviceAccountCredential.ProjectId,
                //Publisher = "google",
                //Location = "us-central1",
                /*GenerationConfig = new GenerationConfig
                {
                    Temperature = 0.4f,
                    TopP = 1,
                    TopK = 32,
                    MaxOutputTokens = 2048
                }*/
            };

            var provider = new VertexAIProvider(config);
            var model = new Gemini15ProModel(provider);

            string answer = await model.GenerateAsync("Generate some random name:");

            answer.Should().NotBeNull();

            Console.WriteLine(answer);
        }
    }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use provided
settings
parameter instead of hardcodingChatSettings.Default
.The
UsedSettings
property of theChatResponse
object is hardcoded toChatSettings.Default
. It should use thesettings
parameter if provided.Committable suggestion
Remove unnecessary re-assignment of
request
.The
request
parameter is re-assigned after the null-check, which is unnecessary.- request = request ?? throw new ArgumentNullException(nameof(request));
Committable suggestion