feat: added Amazon Bedrock Cohere R plus and Cohere R models #285
fix: fixed small issues
Walkthrough
The updates across various files in the Amazon Bedrock module focus on enhancing chat and embedding functionalities, adjusting settings for maximal tokens, and refining model identifiers. New classes and methods are introduced to support asynchronous operations and JSON body creation, improving interaction with Cohere and Amazon Titan models.
Actionable comments posted: 2
Out-of-diff-range and nitpick comments (3)
src/Providers/Amazon.Bedrock/src/Embedding/Settings/AmazonV2EmbeddingSettings.cs (1)
4-55: Introduced AmazonV2EmbeddingSettings with default settings and dynamic calculation capabilities.
Consider adding detailed comments explaining the use of the new keyword in static members to avoid confusion.
src/Providers/Amazon.Bedrock/src/Chat/Settings/MetaLlama2ChatSettings.cs (1)
Line range hint 1-12: Introduced MetaLlama2ChatSettings with default settings and dynamic calculation capabilities.
Consider adding detailed comments explaining the use of the new keyword in static members to avoid confusion.
src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs (1)
[!TIP]
Codebase Verification
The verification process has revealed that the transition to the new instantiation method using RegionEndpoint.USWest2 in BedrockProvider is incomplete. Several files still use the old instantiation method without the RegionEndpoint parameter. Here are the files where issues were found:
src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs
src/Providers/Amazon.Bedrock/test/BedrockTests.cs
src/Meta/test/CrewTests.cs
src/Providers/Amazon.Sagemaker/test/SageMakerTests.cs
src/Databases/IntegrationTests/DatabaseTests.OpenSearch.cs
Analysis chain
Line range hint 24-61: Updated test setup in BedrockTextModelTests to include the RegionEndpoint parameter, enhancing specificity in AWS region selection.
Scripts Executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that the new `RegionEndpoint` parameter is correctly used in all relevant test setups.

# Search for the old instantiation without the `RegionEndpoint` to ensure it's no longer used.
rg --files-with-matches --type cs 'new BedrockProvider()' | while read file; do
  # Expect: No matches.
  echo "Old instantiation found in $file"
done

# Search for the new instantiation to validate usage.
rg --files-with-matches --type cs 'new BedrockProvider(RegionEndpoint.USWest2)' | while read file; do
  # Expect: Matches.
  echo "New instantiation used in $file"
done

Length of output: 1038
public async Task<EmbeddingResponse> CreateEmbeddingsAsync(
    EmbeddingRequest request,
    EmbeddingSettings? settings = null,
    CancellationToken cancellationToken = default)
{
    request = request ?? throw new ArgumentNullException(nameof(request));

    var watch = Stopwatch.StartNew();

    var usedSettings = AmazonV2EmbeddingSettings.Calculate(
        requestSettings: settings,
        modelSettings: Settings,
        providerSettings: provider.EmbeddingSettings);

    var embeddings = new List<float[]>(capacity: request.Strings.Count);

    var tasks = request.Strings.Select(text =>
        {
            var bodyJson = CreateBodyJson(text, usedSettings);
            return provider.Api.InvokeModelAsync(Id, bodyJson,
                cancellationToken);
        })
        .ToList();
    var results = await Task.WhenAll(tasks).ConfigureAwait(false);

    foreach (var response in results)
    {
        var embedding = response?["embedding"]?.AsArray();
        if (embedding == null) continue;

        var f = new float[(int)usedSettings.Dimensions!];
        for (var i = 0; i < embedding.Count; i++)
        {
            f[i] = (float)embedding[(Index)i]?.AsValue()!;
        }

        embeddings.Add(f);
    }

    var usage = Usage.Empty with
    {
        Time = watch.Elapsed,
    };
    AddUsage(usage);
    provider.AddUsage(usage);

    return new EmbeddingResponse
    {
        Values = embeddings.ToArray(),
        Usage = Usage.Empty,
        UsedSettings = usedSettings,
        Dimensions = embeddings.FirstOrDefault()?.Length ?? 0,
    };
}
Implemented CreateEmbeddingsAsync in AmazonTitanEmbeddingV2Model with comprehensive error handling and efficient asynchronous operations.
Optimize JSON body construction by directly using JsonObject properties instead of intermediate conversions.
- var bodyJson = new JsonObject
- {
- ["inputText"] = prompt,
- ["dimensions"] = usedSettings.Dimensions,
- ["normalize"] = usedSettings.Normalize
- };
+ var bodyJson = new JsonObject
+ {
+ { "inputText", prompt },
+ { "dimensions", usedSettings.Dimensions },
+ { "normalize", usedSettings.Normalize }
+ };
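For readers comparing the two forms in the diff above, here is a minimal sketch, assuming System.Text.Json.Nodes and hypothetical sample values in place of prompt and usedSettings. The indexer initializer assigns through the string indexer, while the collection initializer calls JsonObject.Add, which throws on a duplicate key instead of overwriting it. With distinct keys, both forms should serialize to the same payload.

```csharp
using System;
using System.Text.Json.Nodes;

// Hypothetical sample values standing in for prompt and usedSettings.
string prompt = "hello";
int dimensions = 256;
bool normalize = true;

// Indexer initializer: each entry is assigned through the string indexer.
var viaIndexer = new JsonObject
{
    ["inputText"] = prompt,
    ["dimensions"] = dimensions,
    ["normalize"] = normalize,
};

// Collection initializer: each entry calls JsonObject.Add(string, JsonNode?).
var viaAdd = new JsonObject
{
    { "inputText", prompt },
    { "dimensions", dimensions },
    { "normalize", normalize },
};

string a = viaIndexer.ToJsonString();
string b = viaAdd.ToJsonString();

Console.WriteLine(a);
Console.WriteLine(a == b);
```

Under these assumptions the choice between the two forms is mostly stylistic; the sketch does not measure any performance difference.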
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
public async Task<EmbeddingResponse> CreateEmbeddingsAsync(
    EmbeddingRequest request,
    EmbeddingSettings? settings = null,
    CancellationToken cancellationToken = default)
{
    request = request ?? throw new ArgumentNullException(nameof(request));
    var watch = Stopwatch.StartNew();
    var usedSettings = AmazonV2EmbeddingSettings.Calculate(
        requestSettings: settings,
        modelSettings: Settings,
        providerSettings: provider.EmbeddingSettings);
    var embeddings = new List<float[]>(capacity: request.Strings.Count);
    var tasks = request.Strings.Select(text =>
        {
            var bodyJson = new JsonObject
            {
                { "inputText", text },
                { "dimensions", usedSettings.Dimensions },
                { "normalize", usedSettings.Normalize }
            };
            return provider.Api.InvokeModelAsync(Id, bodyJson,
                cancellationToken);
        })
        .ToList();
    var results = await Task.WhenAll(tasks).ConfigureAwait(false);
    foreach (var response in results)
    {
        var embedding = response?["embedding"]?.AsArray();
        if (embedding == null) continue;
        var f = new float[(int)usedSettings.Dimensions!];
        for (var i = 0; i < embedding.Count; i++)
        {
            f[i] = (float)embedding[(Index)i]?.AsValue()!;
        }
        embeddings.Add(f);
    }
    var usage = Usage.Empty with
    {
        Time = watch.Elapsed,
    };
    AddUsage(usage);
    provider.AddUsage(usage);
    return new EmbeddingResponse
    {
        Values = embeddings.ToArray(),
        Usage = Usage.Empty,
        UsedSettings = usedSettings,
        Dimensions = embeddings.FirstOrDefault()?.Length ?? 0,
    };
}
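In the method above, the float buffer is sized from usedSettings.Dimensions while the copy loop runs to embedding.Count, so the two are assumed to agree. The following is only a sketch of an alternative that sizes the vector from the returned array itself. The response body and the ToFloatArray helper are hypothetical and not part of this PR.

```csharp
using System;
using System.Linq;
using System.Text.Json.Nodes;

// Hypothetical response body; in the PR the real one comes from InvokeModelAsync.
JsonNode? response = JsonNode.Parse("{\"embedding\":[0.5,-1.25,2.0]}");

// Hypothetical helper: copies a JSON number array into a float[] whose
// length is taken from the array, not from a separate setting.
static float[] ToFloatArray(JsonArray array) =>
    array.Select(node => node!.GetValue<float>()).ToArray();

JsonArray? embedding = response?["embedding"]?.AsArray();
float[] vector = embedding is null
    ? Array.Empty<float>()
    : ToFloatArray(embedding);

Console.WriteLine(vector.Length);
```

Whether a mismatch between the configured dimensions and the returned length should instead be treated as an error is a design decision for the maintainers.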
public override async Task<ChatResponse> GenerateAsync(
    ChatRequest request,
    ChatSettings? settings = null,
    CancellationToken cancellationToken = default)
{
    request = request ?? throw new ArgumentNullException(nameof(request));

    var watch = Stopwatch.StartNew();
    var prompt = request.Messages.ToSimplePrompt();
    var messages = request.Messages.ToList();

    var stringBuilder = new StringBuilder();

    var usedSettings = CohereCommandChatSettings.Calculate(
        requestSettings: settings,
        modelSettings: Settings,
        providerSettings: provider.ChatSettings);

    var bodyJson = CreateBodyJson(prompt, usedSettings);

    if (usedSettings.UseStreaming == true)
    {
        var streamRequest = BedrockModelRequest.CreateStreamRequest(Id, bodyJson);
        var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false);

        foreach (var payloadPart in response.Body)
        {
            var streamEvent = (PayloadPart)payloadPart;
            var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
                .ConfigureAwait(false);
            var delta = chunk?["text"]?.GetValue<string>() ?? string.Empty;

            OnPartialResponseGenerated(delta);
            stringBuilder.Append(delta);

            var finished = chunk?["finish_reason"]?.GetValue<string>() ?? string.Empty;
            if (string.Equals(finished.ToUpperInvariant(), "COMPLETE", StringComparison.Ordinal))
            {
                OnCompletedResponseGenerated(stringBuilder.ToString());
            }
        }
    }
    else
    {
        var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
            .ConfigureAwait(false);

        var generatedText = response?["text"]?.GetValue<string>() ?? string.Empty;

        messages.Add(generatedText.AsAiMessage());
        OnCompletedResponseGenerated(generatedText);
    }

    var usage = Usage.Empty with
    {
        Time = watch.Elapsed,
    };
    AddUsage(usage);
    provider.AddUsage(usage);

    return new ChatResponse
    {
        Messages = messages,
        UsedSettings = usedSettings,
        Usage = usage,
    };
}
Implemented GenerateAsync in CohereCommandRModel with comprehensive error handling and efficient asynchronous operations.
Optimize JSON body construction by directly using JsonObject properties instead of intermediate conversions.
- var bodyJson = new JsonObject
- {
- ["message"] = prompt,
- };
+ var bodyJson = new JsonObject
+ {
+ { "message", prompt },
+ };
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
public override async Task<ChatResponse> GenerateAsync(
    ChatRequest request,
    ChatSettings? settings = null,
    CancellationToken cancellationToken = default)
{
    request = request ?? throw new ArgumentNullException(nameof(request));
    var watch = Stopwatch.StartNew();
    var prompt = request.Messages.ToSimplePrompt();
    var messages = request.Messages.ToList();
    var stringBuilder = new StringBuilder();
    var usedSettings = CohereCommandChatSettings.Calculate(
        requestSettings: settings,
        modelSettings: Settings,
        providerSettings: provider.ChatSettings);
    var bodyJson = new JsonObject
    {
        { "message", prompt },
    };
    if (usedSettings.UseStreaming == true)
    {
        var streamRequest = BedrockModelRequest.CreateStreamRequest(Id, bodyJson);
        var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false);
        foreach (var payloadPart in response.Body)
        {
            var streamEvent = (PayloadPart)payloadPart;
            var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
                .ConfigureAwait(false);
            var delta = chunk?["text"]?.GetValue<string>() ?? string.Empty;
            OnPartialResponseGenerated(delta);
            stringBuilder.Append(delta);
            var finished = chunk?["finish_reason"]?.GetValue<string>() ?? string.Empty;
            if (string.Equals(finished.ToUpperInvariant(), "COMPLETE", StringComparison.Ordinal))
            {
                OnCompletedResponseGenerated(stringBuilder.ToString());
            }
        }
    }
    else
    {
        var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
            .ConfigureAwait(false);
        var generatedText = response?["text"]?.GetValue<string>() ?? string.Empty;
        messages.Add(generatedText.AsAiMessage());
        OnCompletedResponseGenerated(generatedText);
    }
    var usage = Usage.Empty with
    {
        Time = watch.Elapsed,
    };
    AddUsage(usage);
    provider.AddUsage(usage);
    return new ChatResponse
    {
        Messages = messages,
        UsedSettings = usedSettings,
        Usage = usage,
    };
}
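The chunk-accumulation loop in the streaming branch can be illustrated in isolation. This is a minimal sketch, assuming chunk payloads that carry the same "text" and "finish_reason" fields read above. The sample payloads are hypothetical, and a case-insensitive comparison is used here in place of ToUpperInvariant.

```csharp
using System;
using System.Text;
using System.Text.Json.Nodes;

// Hypothetical chunk payloads; real ones arrive as PayloadPart bytes.
string[] chunks =
{
    "{\"text\":\"Hello\"}",
    "{\"text\":\", world\"}",
    "{\"text\":\"\",\"finish_reason\":\"COMPLETE\"}",
};

var builder = new StringBuilder();
bool completed = false;

foreach (string json in chunks)
{
    JsonNode? chunk = JsonNode.Parse(json);

    // Append the partial text, treating a missing field as empty.
    builder.Append(chunk?["text"]?.GetValue<string>() ?? string.Empty);

    // Mark completion when the chunk reports finish_reason "COMPLETE".
    string finished = chunk?["finish_reason"]?.GetValue<string>() ?? string.Empty;
    if (string.Equals(finished, "COMPLETE", StringComparison.OrdinalIgnoreCase))
    {
        completed = true;
    }
}

string generatedText = builder.ToString();
Console.WriteLine(generatedText);
```

In the method as shown, only the non-streaming branch adds the generated text to messages. A caller of a loop like this could append generatedText once completed is true, if that is the intended behavior.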
fix: fixed small issues
Summary by CodeRabbit
New Features
Enhancements
Bug Fixes
Refactor