-
-
Notifications
You must be signed in to change notification settings - Fork 97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: added Amazon Bedrock Cohere R plus and Cohere R models #285
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
105 changes: 105 additions & 0 deletions
105
src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
using System.Diagnostics; | ||
using System.Text; | ||
using System.Text.Json; | ||
using System.Text.Json.Nodes; | ||
using Amazon.BedrockRuntime.Model; | ||
using LangChain.Providers.Amazon.Bedrock.Internal; | ||
|
||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public abstract class CohereCommandRModel( | ||
BedrockProvider provider, | ||
string id) | ||
: ChatModel(id) | ||
{ | ||
/// <summary> | ||
/// Generates a chat response based on the provided `ChatRequest`. | ||
/// </summary> | ||
/// <param name="request">The `ChatRequest` containing the input messages and other parameters.</param> | ||
/// <param name="settings">Optional `ChatSettings` to override the model's default settings.</param> | ||
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param> | ||
/// <returns>A `ChatResponse` containing the generated messages and usage information.</returns> | ||
public override async Task<ChatResponse> GenerateAsync( | ||
ChatRequest request, | ||
ChatSettings? settings = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
request = request ?? throw new ArgumentNullException(nameof(request)); | ||
|
||
var watch = Stopwatch.StartNew(); | ||
var prompt = request.Messages.ToSimplePrompt(); | ||
var messages = request.Messages.ToList(); | ||
|
||
var stringBuilder = new StringBuilder(); | ||
|
||
var usedSettings = CohereCommandChatSettings.Calculate( | ||
requestSettings: settings, | ||
modelSettings: Settings, | ||
providerSettings: provider.ChatSettings); | ||
|
||
var bodyJson = CreateBodyJson(prompt, usedSettings); | ||
|
||
if (usedSettings.UseStreaming == true) | ||
{ | ||
var streamRequest = BedrockModelRequest.CreateStreamRequest(Id, bodyJson); | ||
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false); | ||
|
||
foreach (var payloadPart in response.Body) | ||
{ | ||
var streamEvent = (PayloadPart)payloadPart; | ||
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken) | ||
.ConfigureAwait(false); | ||
var delta = chunk?["text"]?.GetValue<string>() ?? string.Empty; | ||
|
||
OnPartialResponseGenerated(delta); | ||
stringBuilder.Append(delta); | ||
|
||
var finished = chunk?["finish_reason"]?.GetValue<string>() ?? string.Empty; | ||
if (string.Equals(finished.ToUpperInvariant(), "COMPLETE", StringComparison.Ordinal)) | ||
{ | ||
OnCompletedResponseGenerated(stringBuilder.ToString()); | ||
} | ||
} | ||
} | ||
else | ||
{ | ||
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) | ||
.ConfigureAwait(false); | ||
|
||
var generatedText = response?["text"]?.GetValue<string>() ?? string.Empty; | ||
|
||
messages.Add(generatedText.AsAiMessage()); | ||
OnCompletedResponseGenerated(generatedText); | ||
} | ||
|
||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
AddUsage(usage); | ||
provider.AddUsage(usage); | ||
|
||
return new ChatResponse | ||
{ | ||
Messages = messages, | ||
UsedSettings = usedSettings, | ||
Usage = usage, | ||
}; | ||
} | ||
|
||
/// <summary> | ||
/// Creates the request body JSON for the Cohere model based on the provided prompt and settings. | ||
/// </summary> | ||
/// <param name="prompt">The input prompt for the model.</param> | ||
/// <param name="usedSettings">The settings to use for the request.</param> | ||
/// <returns>A `JsonObject` representing the request body.</returns> | ||
private static JsonObject CreateBodyJson(string prompt, CohereCommandChatSettings usedSettings) | ||
{ | ||
var bodyJson = new JsonObject | ||
{ | ||
["message"] = prompt, | ||
}; | ||
return bodyJson; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
88 changes: 88 additions & 0 deletions
88
src/Providers/Amazon.Bedrock/src/Embedding/AmazonTitanEmbeddingV2Model.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
using System.Diagnostics; | ||
using System.Text.Json.Nodes; | ||
using LangChain.Providers.Amazon.Bedrock.Internal; | ||
|
||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public abstract class AmazonTitanEmbeddingV2Model( | ||
BedrockProvider provider, | ||
string id) | ||
: Model<EmbeddingSettings>(id), IEmbeddingModel | ||
{ | ||
|
||
/// <summary> | ||
/// Creates embeddings for the input strings using the Amazon model. | ||
/// </summary> | ||
/// <param name="request">The `EmbeddingRequest` containing the input strings.</param> | ||
/// <param name="settings">Optional `EmbeddingSettings` to override the model's default settings.</param> | ||
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param> | ||
/// <returns>An `EmbeddingResponse` containing the generated embeddings and usage information.</returns> | ||
|
||
public async Task<EmbeddingResponse> CreateEmbeddingsAsync( | ||
EmbeddingRequest request, | ||
EmbeddingSettings? settings = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
request = request ?? throw new ArgumentNullException(nameof(request)); | ||
|
||
var watch = Stopwatch.StartNew(); | ||
|
||
var usedSettings = AmazonV2EmbeddingSettings.Calculate( | ||
requestSettings: settings, | ||
modelSettings: Settings, | ||
providerSettings: provider.EmbeddingSettings); | ||
|
||
var embeddings = new List<float[]>(capacity: request.Strings.Count); | ||
|
||
var tasks = request.Strings.Select(text => | ||
{ | ||
var bodyJson = CreateBodyJson(text, usedSettings); | ||
return provider.Api.InvokeModelAsync(Id, bodyJson, | ||
cancellationToken); | ||
}) | ||
.ToList(); | ||
var results = await Task.WhenAll(tasks).ConfigureAwait(false); | ||
|
||
foreach (var response in results) | ||
{ | ||
var embedding = response?["embedding"]?.AsArray(); | ||
if (embedding == null) continue; | ||
|
||
var f = new float[(int)usedSettings.Dimensions!]; | ||
for (var i = 0; i < embedding.Count; i++) | ||
{ | ||
f[i] = (float)embedding[(Index)i]?.AsValue()!; | ||
} | ||
|
||
embeddings.Add(f); | ||
} | ||
|
||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
AddUsage(usage); | ||
provider.AddUsage(usage); | ||
|
||
return new EmbeddingResponse | ||
{ | ||
Values = embeddings.ToArray(), | ||
Usage = Usage.Empty, | ||
UsedSettings = usedSettings, | ||
Dimensions = embeddings.FirstOrDefault()?.Length ?? 0, | ||
}; | ||
} | ||
|
||
private static JsonObject CreateBodyJson(string? prompt, AmazonV2EmbeddingSettings usedSettings) | ||
{ | ||
var bodyJson = new JsonObject | ||
{ | ||
["inputText"] = prompt, | ||
["dimensions"] = usedSettings.Dimensions, | ||
["normalize"] = usedSettings.Normalize | ||
}; | ||
|
||
return bodyJson; | ||
} | ||
} |
55 changes: 55 additions & 0 deletions
55
src/Providers/Amazon.Bedrock/src/Embedding/Settings/AmazonV2EmbeddingSettings.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public class AmazonV2EmbeddingSettings : BedrockEmbeddingSettings | ||
{ | ||
public new static AmazonV2EmbeddingSettings Default { get; } = new() | ||
{ | ||
Dimensions = 1024, | ||
MaximumInputLength = 10_000, | ||
Normalize = true | ||
}; | ||
|
||
public bool Normalize { get; set; } | ||
|
||
/// <summary> | ||
/// Calculate the settings to use for the request. | ||
/// </summary> | ||
/// <param name="requestSettings"></param> | ||
/// <param name="modelSettings"></param> | ||
/// <param name="providerSettings"></param> | ||
/// <returns></returns> | ||
/// <exception cref="InvalidOperationException"></exception> | ||
public new static AmazonV2EmbeddingSettings Calculate( | ||
EmbeddingSettings? requestSettings, | ||
EmbeddingSettings? modelSettings, | ||
EmbeddingSettings? providerSettings) | ||
{ | ||
var requestSettingsCasted = requestSettings as AmazonV2EmbeddingSettings; | ||
var modelSettingsCasted = modelSettings as AmazonV2EmbeddingSettings; | ||
var providerSettingsCasted = providerSettings as AmazonV2EmbeddingSettings; | ||
|
||
return new AmazonV2EmbeddingSettings | ||
{ | ||
Dimensions = | ||
requestSettingsCasted?.Dimensions ?? | ||
modelSettingsCasted?.Dimensions ?? | ||
providerSettingsCasted?.Dimensions ?? | ||
Default.Dimensions ?? | ||
throw new InvalidOperationException("Default Dimensions is not set."), | ||
|
||
MaximumInputLength = | ||
requestSettingsCasted?.MaximumInputLength ?? | ||
modelSettingsCasted?.MaximumInputLength ?? | ||
providerSettingsCasted?.MaximumInputLength ?? | ||
Default.MaximumInputLength ?? | ||
throw new InvalidOperationException("Default MaximumInputLength is not set."), | ||
|
||
Normalize = | ||
requestSettingsCasted?.Normalize ?? | ||
modelSettingsCasted?.Normalize ?? | ||
providerSettingsCasted?.Normalize ?? | ||
Default.Normalize, | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implemented
CreateEmbeddingsAsync
inAmazonTitanEmbeddingV2Model
with comprehensive error handling and efficient asynchronous operations.Optimize JSON body construction by directly using
JsonObject
properties instead of intermediate conversions.Committable suggestion