feat: added Amazon Bedrock Cohere R plus and Cohere R models #285

Merged · 2 commits · May 2, 2024
105 changes: 105 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs
@@ -0,0 +1,105 @@
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using Amazon.BedrockRuntime.Model;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public abstract class CohereCommandRModel(
BedrockProvider provider,
string id)
: ChatModel(id)
{
/// <summary>
/// Generates a chat response based on the provided `ChatRequest`.
/// </summary>
/// <param name="request">The `ChatRequest` containing the input messages and other parameters.</param>
/// <param name="settings">Optional `ChatSettings` to override the model's default settings.</param>
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
/// <returns>A `ChatResponse` containing the generated messages and usage information.</returns>
public override async Task<ChatResponse> GenerateAsync(
ChatRequest request,
ChatSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();
var prompt = request.Messages.ToSimplePrompt();
var messages = request.Messages.ToList();

var stringBuilder = new StringBuilder();

var usedSettings = CohereCommandChatSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.ChatSettings);

var bodyJson = CreateBodyJson(prompt, usedSettings);

if (usedSettings.UseStreaming == true)
{
var streamRequest = BedrockModelRequest.CreateStreamRequest(Id, bodyJson);
var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false);

foreach (var payloadPart in response.Body)
{
var streamEvent = (PayloadPart)payloadPart;
var chunk = await JsonSerializer.DeserializeAsync<JsonObject>(streamEvent.Bytes, cancellationToken: cancellationToken)
.ConfigureAwait(false);
var delta = chunk?["text"]?.GetValue<string>() ?? string.Empty;

OnPartialResponseGenerated(delta);
stringBuilder.Append(delta);

var finished = chunk?["finish_reason"]?.GetValue<string>() ?? string.Empty;
if (string.Equals(finished.ToUpperInvariant(), "COMPLETE", StringComparison.Ordinal))
{
OnCompletedResponseGenerated(stringBuilder.ToString());
}
}
}
else
{
var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken)
.ConfigureAwait(false);

var generatedText = response?["text"]?.GetValue<string>() ?? string.Empty;

messages.Add(generatedText.AsAiMessage());
OnCompletedResponseGenerated(generatedText);
}

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new ChatResponse
{
Messages = messages,
UsedSettings = usedSettings,
Usage = usage,
};
}
Contributor comment on lines +23 to +89:

Implemented GenerateAsync in CohereCommandRModel with comprehensive error handling and efficient asynchronous operations.

Optimize JSON body construction by directly using JsonObject properties instead of intermediate conversions.

- var bodyJson = new JsonObject
- {
-     ["message"] = prompt,
- };
+ var bodyJson = new JsonObject
+ {
+     { "message", prompt },
+ };
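
For reference, both initializer forms are supported by System.Text.Json.Nodes.JsonObject and serialize to the same JSON: the index form assigns through the indexer (overwriting a duplicate key), while the collection form calls Add (throwing on a duplicate key). A small standalone check, not part of this PR:

using System.Text.Json.Nodes;

// Index initializer: assigns through the indexer (silently overwrites a duplicate key).
var indexed = new JsonObject { ["message"] = "hi" };

// Collection initializer: calls JsonObject.Add (throws ArgumentException on a duplicate key).
var added = new JsonObject { { "message", "hi" } };

Console.WriteLine(indexed.ToJsonString()); // {"message":"hi"}
Console.WriteLine(added.ToJsonString());   // {"message":"hi"}
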



/// <summary>
/// Creates the request body JSON for the Cohere model based on the provided prompt and settings.
/// </summary>
/// <param name="prompt">The input prompt for the model.</param>
/// <param name="usedSettings">The settings to use for the request.</param>
/// <returns>A `JsonObject` representing the request body.</returns>
private static JsonObject CreateBodyJson(string prompt, CohereCommandChatSettings usedSettings)
{
var bodyJson = new JsonObject
{
["message"] = prompt,
};
return bodyJson;
}
}
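
As a quick orientation for reviewers, a minimal usage sketch of the new chat model, following the patterns used in the test changes later in this PR (predefined model class, event subscription, GenerateAsync). The prompt text is illustrative, and credentials/region are assumed to come from the default AWS configuration.

using LangChain.Providers.Amazon.Bedrock;
using LangChain.Providers.Amazon.Bedrock.Predefined.Cohere;

// Minimal sketch: exercise the new Command R chat model end to end.
var provider = new BedrockProvider();       // default AWS credentials/region chain assumed
var llm = new CommandRModel(provider);      // predefined id: cohere.command-r-v1:0

// Print streamed tokens as they arrive (fires when UseStreaming is enabled in the settings).
llm.PartialResponseGenerated += (_, delta) => Console.Write(delta);

var response = await llm.GenerateAsync("What is a good name for a company that makes colorful socks?");
Console.WriteLine(response.LastMessageContent);
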
@@ -9,7 +9,7 @@ public class MetaLlama2ChatSettings : BedrockChatSettings
User = ChatSettings.Default.User,
UseStreaming = false,
Temperature = 0.5,
MaxTokens = 4000,
MaxTokens = 2048,
TopP = 0.9,
TopK = 0.0
};
@@ -0,0 +1,88 @@
using System.Diagnostics;
using System.Text.Json.Nodes;
using LangChain.Providers.Amazon.Bedrock.Internal;

// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public abstract class AmazonTitanEmbeddingV2Model(
BedrockProvider provider,
string id)
: Model<EmbeddingSettings>(id), IEmbeddingModel
{

/// <summary>
/// Creates embeddings for the input strings using the Amazon model.
/// </summary>
/// <param name="request">The `EmbeddingRequest` containing the input strings.</param>
/// <param name="settings">Optional `EmbeddingSettings` to override the model's default settings.</param>
/// <param name="cancellationToken">A cancellation token to cancel the operation.</param>
/// <returns>An `EmbeddingResponse` containing the generated embeddings and usage information.</returns>

public async Task<EmbeddingResponse> CreateEmbeddingsAsync(
EmbeddingRequest request,
EmbeddingSettings? settings = null,
CancellationToken cancellationToken = default)
{
request = request ?? throw new ArgumentNullException(nameof(request));

var watch = Stopwatch.StartNew();

var usedSettings = AmazonV2EmbeddingSettings.Calculate(
requestSettings: settings,
modelSettings: Settings,
providerSettings: provider.EmbeddingSettings);

var embeddings = new List<float[]>(capacity: request.Strings.Count);

var tasks = request.Strings.Select(text =>
{
var bodyJson = CreateBodyJson(text, usedSettings);
return provider.Api.InvokeModelAsync(Id, bodyJson,
cancellationToken);
})
.ToList();
var results = await Task.WhenAll(tasks).ConfigureAwait(false);

foreach (var response in results)
{
var embedding = response?["embedding"]?.AsArray();
if (embedding == null) continue;

var f = new float[(int)usedSettings.Dimensions!];
for (var i = 0; i < embedding.Count; i++)
{
f[i] = (float)embedding[(Index)i]?.AsValue()!;
}

embeddings.Add(f);
}

var usage = Usage.Empty with
{
Time = watch.Elapsed,
};
AddUsage(usage);
provider.AddUsage(usage);

return new EmbeddingResponse
{
Values = embeddings.ToArray(),
Usage = Usage.Empty,
UsedSettings = usedSettings,
Dimensions = embeddings.FirstOrDefault()?.Length ?? 0,
};
}
Contributor comment on lines +22 to +75:

Implemented CreateEmbeddingsAsync in AmazonTitanEmbeddingV2Model with comprehensive error handling and efficient asynchronous operations.

Optimize JSON body construction by directly using JsonObject properties instead of intermediate conversions.

- var bodyJson = new JsonObject
- {
-     ["inputText"] = prompt,
-     ["dimensions"] = usedSettings.Dimensions,
-     ["normalize"] = usedSettings.Normalize
- };
+ var bodyJson = new JsonObject
+ {
+     { "inputText", prompt },
+     { "dimensions", usedSettings.Dimensions },
+     { "normalize", usedSettings.Normalize }
+ };



private static JsonObject CreateBodyJson(string? prompt, AmazonV2EmbeddingSettings usedSettings)
{
var bodyJson = new JsonObject
{
["inputText"] = prompt,
["dimensions"] = usedSettings.Dimensions,
["normalize"] = usedSettings.Normalize
};

return bodyJson;
}
}
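
For completeness, a sketch of how the new Titan V2 embedding model might be called through the predefined TitanEmbedTextV2Model added later in this PR. The EmbeddingRequest initializer is an assumption based on the request.Strings access in the implementation above.

using LangChain.Providers.Amazon.Bedrock;
using LangChain.Providers.Amazon.Bedrock.Predefined.Amazon;

var provider = new BedrockProvider();
var embedModel = new TitanEmbedTextV2Model(provider);   // amazon.titan-embed-text-v2:0

// Ask for 256-dimensional, normalized vectors via the new AmazonV2EmbeddingSettings.
var settings = new AmazonV2EmbeddingSettings { Dimensions = 256, Normalize = true };

// Assumed request shape: the implementation reads request.Strings, so a settable collection is assumed here.
var request = new EmbeddingRequest { Strings = new[] { "hello world", "salut tout le monde" } };

var response = await embedModel.CreateEmbeddingsAsync(request, settings);
Console.WriteLine($"{response.Values.Length} vectors, {response.Dimensions} dimensions each");
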
@@ -0,0 +1,55 @@
// ReSharper disable once CheckNamespace
namespace LangChain.Providers.Amazon.Bedrock;

public class AmazonV2EmbeddingSettings : BedrockEmbeddingSettings
{
public new static AmazonV2EmbeddingSettings Default { get; } = new()
{
Dimensions = 1024,
MaximumInputLength = 10_000,
Normalize = true
};

public bool Normalize { get; set; }

/// <summary>
/// Calculate the settings to use for the request.
/// </summary>
/// <param name="requestSettings"></param>
/// <param name="modelSettings"></param>
/// <param name="providerSettings"></param>
/// <returns></returns>
/// <exception cref="InvalidOperationException"></exception>
public new static AmazonV2EmbeddingSettings Calculate(
EmbeddingSettings? requestSettings,
EmbeddingSettings? modelSettings,
EmbeddingSettings? providerSettings)
{
var requestSettingsCasted = requestSettings as AmazonV2EmbeddingSettings;
var modelSettingsCasted = modelSettings as AmazonV2EmbeddingSettings;
var providerSettingsCasted = providerSettings as AmazonV2EmbeddingSettings;

return new AmazonV2EmbeddingSettings
{
Dimensions =
requestSettingsCasted?.Dimensions ??
modelSettingsCasted?.Dimensions ??
providerSettingsCasted?.Dimensions ??
Default.Dimensions ??
throw new InvalidOperationException("Default Dimensions is not set."),

MaximumInputLength =
requestSettingsCasted?.MaximumInputLength ??
modelSettingsCasted?.MaximumInputLength ??
providerSettingsCasted?.MaximumInputLength ??
Default.MaximumInputLength ??
throw new InvalidOperationException("Default MaximumInputLength is not set."),

Normalize =
requestSettingsCasted?.Normalize ??
modelSettingsCasted?.Normalize ??
providerSettingsCasted?.Normalize ??
Default.Normalize,
};
}
}
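
A short sketch of how the cascade in Calculate resolves, derived from the implementation above: the nullable properties fall through request -> model -> provider -> Default, while Normalize is a non-nullable bool, so any request-level AmazonV2EmbeddingSettings supplies it directly (false unless explicitly set).

// Only Dimensions is set at the request level; the other values resolve per the cascade above.
var used = AmazonV2EmbeddingSettings.Calculate(
    requestSettings: new AmazonV2EmbeddingSettings { Dimensions = 512 },
    modelSettings: null,
    providerSettings: null);

Console.WriteLine(used.Dimensions);         // 512 (from the request)
Console.WriteLine(used.MaximumInputLength); // 10000 (falls through to Default)
Console.WriteLine(used.Normalize);          // False (the request object's non-nullable bool wins)
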
4 changes: 4 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Predefined/Amazon.cs
@@ -13,6 +13,10 @@ public class TitanTextLiteV1Model(BedrockProvider provider)
public class TitanEmbedTextV1Model(BedrockProvider provider)
: AmazonTitanEmbeddingModel(provider, id: "amazon.titan-embed-text-v1");

/// <inheritdoc />
public class TitanEmbedTextV2Model(BedrockProvider provider)
: AmazonTitanEmbeddingV2Model(provider, id: "amazon.titan-embed-text-v2:0");

/// <inheritdoc />
public class TitanEmbedImageV1Model(BedrockProvider provider)
: AmazonTitanImageEmbeddingModel(provider, id: "amazon.titan-embed-image-v1");
8 changes: 8 additions & 0 deletions src/Providers/Amazon.Bedrock/src/Predefined/Cohere.cs
@@ -16,3 +16,11 @@ public class EmbedEnglishV3Model(BedrockProvider provider)
/// <inheritdoc />
public class EmbedMultilingualV3Model(BedrockProvider provider)
: CohereEmbeddingModel(provider, id: "cohere.embed-multilingual-v3");

/// <inheritdoc />
public class CommandRPlusModel(BedrockProvider provider)
: CohereCommandRModel(provider, id: "cohere.command-r-plus-v1:0");

/// <inheritdoc />
public class CommandRModel(BedrockProvider provider)
: CohereCommandRModel(provider, id: "cohere.command-r-v1:0");
2 changes: 1 addition & 1 deletion src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs
@@ -7,7 +7,7 @@ public class Mistral7BInstruct(BedrockProvider provider)

/// <inheritdoc />
public class Mistral8x7BInstruct(BedrockProvider provider)
: MistralModel(provider, id: "mistral.mistral-8x7b-instruct-v0:1");
: MistralModel(provider, id: "mistral.mixtral-8x7b-instruct-v0:1");

/// <inheritdoc />
public class MistralLarge(BedrockProvider provider)
7 changes: 4 additions & 3 deletions src/Providers/Amazon.Bedrock/test/BedrockTests.cs
@@ -13,11 +13,11 @@
using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic;
using LangChain.Providers.Amazon.Bedrock.Predefined.Cohere;
using LangChain.Providers.Amazon.Bedrock.Predefined.Meta;
using LangChain.Providers.Amazon.Bedrock.Predefined.Mistral;
using LangChain.Providers.Amazon.Bedrock.Predefined.Stability;
using LangChain.Schema;
using LangChain.Splitters.Text;
using static LangChain.Chains.Chain;
using Microsoft.SemanticKernel.AI.Embeddings;

namespace LangChain.Providers.Amazon.Bedrock.Tests;

@@ -31,7 +31,7 @@ public async Task Chains()
//var llm = new Jurassic2MidModel(provider);
//var llm = new ClaudeV21Model(provider);
//var llm = new Mistral7BInstruct(provider);
var llm = new Claude3SonnetModel(provider);
var llm = new CommandRModel(provider);

var template = "What is a good name for a company that makes {product}?";
var prompt = new PromptTemplate(new PromptTemplateInput(template, new List<string>(1) { "product" }));
@@ -271,7 +271,7 @@ public async Task ClaudeImageToText()
public async Task SimpleTest(bool useStreaming, bool useChatSettings)
{
var provider = new BedrockProvider();
var llm = new CommandLightTextV14Model(provider);
var llm = new CommandRModel(provider);

llm.PromptSent += (_, prompt) => Console.WriteLine($"Prompt: {prompt}");
llm.PartialResponseGenerated += (_, delta) => Console.Write(delta);
@@ -290,6 +290,7 @@ you are a comic book writer. you will be given a question and you will answer i
{
var response = await llm.GenerateAsync(prompt);
response.LastMessageContent.Should().NotBeNull();
Console.WriteLine(response.LastMessageContent);
}
}

18 changes: 13 additions & 5 deletions src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs
@@ -1,6 +1,7 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.CSharp;
using System.Reflection;
using Amazon;

namespace LangChain.Providers.Amazon.Bedrock.Tests;

@@ -20,7 +21,7 @@ public Task TestAllTextLLMs()

var derivedTypeNames = FindDerivedTypes(predefinedDir);

var provider = new BedrockProvider();
var provider = new BedrockProvider(RegionEndpoint.USWest2);

var failedTypes = new List<string>();
var workingTypes = new Dictionary<string, double>();
@@ -44,13 +45,20 @@ public Task TestAllTextLLMs()

Console.WriteLine($"############## {type.FullName}");

object[] args = { provider };
object[] args = [provider];
var llm = (ChatModel)Activator.CreateInstance(type, args)!;
var result = llm.GenerateAsync("who's your favor superhero?");
try
{
var result = llm.GenerateAsync("who's your favorite superhero?");

workingTypes.Add(className, result.Result.Usage.Time.TotalSeconds);
workingTypes.Add(className, result.Result.Usage.Time.TotalSeconds);

Console.WriteLine(result.Result + "\n\n\n");
Console.WriteLine(result.Result + "\n\n\n");
}
catch (Exception e)
{
Console.WriteLine($"**** **** **** ERROR: " + e);
}
}
}
catch (Exception e)