-
-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added Amazon Bedrock Cohere R plus and Cohere R models (#285)
fix: fixed small issues
- Loading branch information
Showing
9 changed files
with
279 additions
and
10 deletions.
There are no files selected for viewing
105 changes: 105 additions & 0 deletions
105
src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
using System.Diagnostics; | ||
using System.Text; | ||
using System.Text.Json; | ||
using System.Text.Json.Nodes; | ||
using Amazon.BedrockRuntime.Model; | ||
using LangChain.Providers.Amazon.Bedrock.Internal; | ||
|
||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public abstract class CohereCommandRModel( | ||
BedrockProvider provider, | ||
string id) | ||
: ChatModel(id) | ||
{ | ||
/// <summary> | ||
/// Generates a chat response based on the provided `ChatRequest`. | ||
/// </summary> | ||
/// <param name="request">The `ChatRequest` containing the input messages and other parameters.</param> | ||
/// <param name="settings">Optional `ChatSettings` to override the model's default settings.</param> | ||
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param> | ||
/// <returns>A `ChatResponse` containing the generated messages and usage information.</returns> | ||
public override async Task<ChatResponse> GenerateAsync( | ||
ChatRequest request, | ||
ChatSettings? settings = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
request = request ?? throw new ArgumentNullException(nameof(request)); | ||
|
||
var watch = Stopwatch.StartNew(); | ||
var prompt = request.Messages.ToSimplePrompt(); | ||
var messages = request.Messages.ToList(); | ||
|
||
var stringBuilder = new StringBuilder(); | ||
|
||
var usedSettings = CohereCommandChatSettings.Calculate( | ||
requestSettings: settings, | ||
modelSettings: Settings, | ||
providerSettings: provider.ChatSettings); | ||
|
||
var bodyJson = CreateBodyJson(prompt, usedSettings); | ||
|
||
if (usedSettings.UseStreaming == true) | ||
{ | ||
var streamRequest = BedrockModelRequest.CreateStreamRequest(Id, bodyJson); | ||
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false); | ||
|
||
foreach (var payloadPart in response.Body) | ||
{ | ||
var streamEvent = (PayloadPart)payloadPart; | ||
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken) | ||
.ConfigureAwait(false); | ||
var delta = chunk?["text"]?.GetValue<string>() ?? string.Empty; | ||
|
||
OnPartialResponseGenerated(delta); | ||
stringBuilder.Append(delta); | ||
|
||
var finished = chunk?["finish_reason"]?.GetValue<string>() ?? string.Empty; | ||
if (string.Equals(finished.ToUpperInvariant(), "COMPLETE", StringComparison.Ordinal)) | ||
{ | ||
OnCompletedResponseGenerated(stringBuilder.ToString()); | ||
} | ||
} | ||
} | ||
else | ||
{ | ||
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) | ||
.ConfigureAwait(false); | ||
|
||
var generatedText = response?["text"]?.GetValue<string>() ?? string.Empty; | ||
|
||
messages.Add(generatedText.AsAiMessage()); | ||
OnCompletedResponseGenerated(generatedText); | ||
} | ||
|
||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
AddUsage(usage); | ||
provider.AddUsage(usage); | ||
|
||
return new ChatResponse | ||
{ | ||
Messages = messages, | ||
UsedSettings = usedSettings, | ||
Usage = usage, | ||
}; | ||
} | ||
|
||
/// <summary> | ||
/// Creates the request body JSON for the Cohere model based on the provided prompt and settings. | ||
/// </summary> | ||
/// <param name="prompt">The input prompt for the model.</param> | ||
/// <param name="usedSettings">The settings to use for the request.</param> | ||
/// <returns>A `JsonObject` representing the request body.</returns> | ||
private static JsonObject CreateBodyJson(string prompt, CohereCommandChatSettings usedSettings) | ||
{ | ||
var bodyJson = new JsonObject | ||
{ | ||
["message"] = prompt, | ||
}; | ||
return bodyJson; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
88 changes: 88 additions & 0 deletions
88
src/Providers/Amazon.Bedrock/src/Embedding/AmazonTitanEmbeddingV2Model.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
using System.Diagnostics; | ||
using System.Text.Json.Nodes; | ||
using LangChain.Providers.Amazon.Bedrock.Internal; | ||
|
||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public abstract class AmazonTitanEmbeddingV2Model( | ||
BedrockProvider provider, | ||
string id) | ||
: Model<EmbeddingSettings>(id), IEmbeddingModel | ||
{ | ||
|
||
/// <summary> | ||
/// Creates embeddings for the input strings using the Amazon model. | ||
/// </summary> | ||
/// <param name="request">The `EmbeddingRequest` containing the input strings.</param> | ||
/// <param name="settings">Optional `EmbeddingSettings` to override the model's default settings.</param> | ||
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param> | ||
/// <returns>An `EmbeddingResponse` containing the generated embeddings and usage information.</returns> | ||
|
||
public async Task<EmbeddingResponse> CreateEmbeddingsAsync( | ||
EmbeddingRequest request, | ||
EmbeddingSettings? settings = null, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
request = request ?? throw new ArgumentNullException(nameof(request)); | ||
|
||
var watch = Stopwatch.StartNew(); | ||
|
||
var usedSettings = AmazonV2EmbeddingSettings.Calculate( | ||
requestSettings: settings, | ||
modelSettings: Settings, | ||
providerSettings: provider.EmbeddingSettings); | ||
|
||
var embeddings = new List<float[]>(capacity: request.Strings.Count); | ||
|
||
var tasks = request.Strings.Select(text => | ||
{ | ||
var bodyJson = CreateBodyJson(text, usedSettings); | ||
return provider.Api.InvokeModelAsync(Id, bodyJson, | ||
cancellationToken); | ||
}) | ||
.ToList(); | ||
var results = await Task.WhenAll(tasks).ConfigureAwait(false); | ||
|
||
foreach (var response in results) | ||
{ | ||
var embedding = response?["embedding"]?.AsArray(); | ||
if (embedding == null) continue; | ||
|
||
var f = new float[(int)usedSettings.Dimensions!]; | ||
for (var i = 0; i < embedding.Count; i++) | ||
{ | ||
f[i] = (float)embedding[(Index)i]?.AsValue()!; | ||
} | ||
|
||
embeddings.Add(f); | ||
} | ||
|
||
var usage = Usage.Empty with | ||
{ | ||
Time = watch.Elapsed, | ||
}; | ||
AddUsage(usage); | ||
provider.AddUsage(usage); | ||
|
||
return new EmbeddingResponse | ||
{ | ||
Values = embeddings.ToArray(), | ||
Usage = Usage.Empty, | ||
UsedSettings = usedSettings, | ||
Dimensions = embeddings.FirstOrDefault()?.Length ?? 0, | ||
}; | ||
} | ||
|
||
private static JsonObject CreateBodyJson(string? prompt, AmazonV2EmbeddingSettings usedSettings) | ||
{ | ||
var bodyJson = new JsonObject | ||
{ | ||
["inputText"] = prompt, | ||
["dimensions"] = usedSettings.Dimensions, | ||
["normalize"] = usedSettings.Normalize | ||
}; | ||
|
||
return bodyJson; | ||
} | ||
} |
55 changes: 55 additions & 0 deletions
55
src/Providers/Amazon.Bedrock/src/Embedding/Settings/AmazonV2EmbeddingSettings.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// ReSharper disable once CheckNamespace | ||
namespace LangChain.Providers.Amazon.Bedrock; | ||
|
||
public class AmazonV2EmbeddingSettings : BedrockEmbeddingSettings | ||
{ | ||
public new static AmazonV2EmbeddingSettings Default { get; } = new() | ||
{ | ||
Dimensions = 1024, | ||
MaximumInputLength = 10_000, | ||
Normalize = true | ||
}; | ||
|
||
public bool Normalize { get; set; } | ||
|
||
/// <summary> | ||
/// Calculate the settings to use for the request. | ||
/// </summary> | ||
/// <param name="requestSettings"></param> | ||
/// <param name="modelSettings"></param> | ||
/// <param name="providerSettings"></param> | ||
/// <returns></returns> | ||
/// <exception cref="InvalidOperationException"></exception> | ||
public new static AmazonV2EmbeddingSettings Calculate( | ||
EmbeddingSettings? requestSettings, | ||
EmbeddingSettings? modelSettings, | ||
EmbeddingSettings? providerSettings) | ||
{ | ||
var requestSettingsCasted = requestSettings as AmazonV2EmbeddingSettings; | ||
var modelSettingsCasted = modelSettings as AmazonV2EmbeddingSettings; | ||
var providerSettingsCasted = providerSettings as AmazonV2EmbeddingSettings; | ||
|
||
return new AmazonV2EmbeddingSettings | ||
{ | ||
Dimensions = | ||
requestSettingsCasted?.Dimensions ?? | ||
modelSettingsCasted?.Dimensions ?? | ||
providerSettingsCasted?.Dimensions ?? | ||
Default.Dimensions ?? | ||
throw new InvalidOperationException("Default Dimensions is not set."), | ||
|
||
MaximumInputLength = | ||
requestSettingsCasted?.MaximumInputLength ?? | ||
modelSettingsCasted?.MaximumInputLength ?? | ||
providerSettingsCasted?.MaximumInputLength ?? | ||
Default.MaximumInputLength ?? | ||
throw new InvalidOperationException("Default MaximumInputLength is not set."), | ||
|
||
Normalize = | ||
requestSettingsCasted?.Normalize ?? | ||
modelSettingsCasted?.Normalize ?? | ||
providerSettingsCasted?.Normalize ?? | ||
Default.Normalize, | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters