From f583b48a294c8b527e14088e43774ef2e7213314 Mon Sep 17 00:00:00 2001 From: Ty Augustine Date: Thu, 2 May 2024 14:14:28 -0400 Subject: [PATCH] feat: added Amazon Bedrock Cohere R plus and Cohere R models fix: fixed small issues --- .../src/Chat/CohereCommandRModel.cs | 105 ++++++++++++++++++ .../Chat/Settings/MetaLlama2ChatSettings.cs | 2 +- .../Embedding/AmazonTitanEmbeddingV2Model.cs | 88 +++++++++++++++ .../Settings/AmazonV2EmbeddingSettings.cs | 55 +++++++++ .../Amazon.Bedrock/src/Predefined/Amazon.cs | 4 + .../Amazon.Bedrock/src/Predefined/Cohere.cs | 8 ++ .../Amazon.Bedrock/src/Predefined/Mistral.cs | 2 +- .../Amazon.Bedrock/test/BedrockTests.cs | 7 +- .../test/BedrockTextModelTests.cs | 18 ++- 9 files changed, 279 insertions(+), 10 deletions(-) create mode 100644 src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs create mode 100644 src/Providers/Amazon.Bedrock/src/Embedding/AmazonTitanEmbeddingV2Model.cs create mode 100644 src/Providers/Amazon.Bedrock/src/Embedding/Settings/AmazonV2EmbeddingSettings.cs diff --git a/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs new file mode 100644 index 00000000..a38a1889 --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandRModel.cs @@ -0,0 +1,105 @@ +using System.Diagnostics; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; +using LangChain.Providers.Amazon.Bedrock.Internal; + +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock; + +public abstract class CohereCommandRModel( + BedrockProvider provider, + string id) + : ChatModel(id) +{ + /// + /// Generates a chat response based on the provided `ChatRequest`. + /// + /// The `ChatRequest` containing the input messages and other parameters. + /// Optional `ChatSettings` to override the model's default settings. + /// A cancellation token to cancel the operation. + /// A `ChatResponse` containing the generated messages and usage information. + public override async Task GenerateAsync( + ChatRequest request, + ChatSettings? settings = null, + CancellationToken cancellationToken = default) + { + request = request ?? throw new ArgumentNullException(nameof(request)); + + var watch = Stopwatch.StartNew(); + var prompt = request.Messages.ToSimplePrompt(); + var messages = request.Messages.ToList(); + + var stringBuilder = new StringBuilder(); + + var usedSettings = CohereCommandChatSettings.Calculate( + requestSettings: settings, + modelSettings: Settings, + providerSettings: provider.ChatSettings); + + var bodyJson = CreateBodyJson(prompt, usedSettings); + + if (usedSettings.UseStreaming == true) + { + var streamRequest = BedrockModelRequest.CreateStreamRequest(Id, bodyJson); + var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken).ConfigureAwait(false); + + foreach (var payloadPart in response.Body) + { + var streamEvent = (PayloadPart)payloadPart; + var chunk = await JsonSerializer.DeserializeAsync(streamEvent.Bytes, cancellationToken: cancellationToken) + .ConfigureAwait(false); + var delta = chunk?["text"]?.GetValue() ?? string.Empty; + + OnPartialResponseGenerated(delta); + stringBuilder.Append(delta); + + var finished = chunk?["finish_reason"]?.GetValue() ?? string.Empty; + if (string.Equals(finished.ToUpperInvariant(), "COMPLETE", StringComparison.Ordinal)) + { + OnCompletedResponseGenerated(stringBuilder.ToString()); + } + } + } + else + { + var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) + .ConfigureAwait(false); + + var generatedText = response?["text"]?.GetValue() ?? string.Empty; + + messages.Add(generatedText.AsAiMessage()); + OnCompletedResponseGenerated(generatedText); + } + + var usage = Usage.Empty with + { + Time = watch.Elapsed, + }; + AddUsage(usage); + provider.AddUsage(usage); + + return new ChatResponse + { + Messages = messages, + UsedSettings = usedSettings, + Usage = usage, + }; + } + + /// + /// Creates the request body JSON for the Cohere model based on the provided prompt and settings. + /// + /// The input prompt for the model. + /// The settings to use for the request. + /// A `JsonObject` representing the request body. + private static JsonObject CreateBodyJson(string prompt, CohereCommandChatSettings usedSettings) + { + var bodyJson = new JsonObject + { + ["message"] = prompt, + }; + return bodyJson; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Chat/Settings/MetaLlama2ChatSettings.cs b/src/Providers/Amazon.Bedrock/src/Chat/Settings/MetaLlama2ChatSettings.cs index 9b82fefd..2ab7a360 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/Settings/MetaLlama2ChatSettings.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/Settings/MetaLlama2ChatSettings.cs @@ -9,7 +9,7 @@ public class MetaLlama2ChatSettings : BedrockChatSettings User = ChatSettings.Default.User, UseStreaming = false, Temperature = 0.5, - MaxTokens = 4000, + MaxTokens = 2048, TopP = 0.9, TopK = 0.0 }; diff --git a/src/Providers/Amazon.Bedrock/src/Embedding/AmazonTitanEmbeddingV2Model.cs b/src/Providers/Amazon.Bedrock/src/Embedding/AmazonTitanEmbeddingV2Model.cs new file mode 100644 index 00000000..9293d3cd --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/Embedding/AmazonTitanEmbeddingV2Model.cs @@ -0,0 +1,88 @@ +using System.Diagnostics; +using System.Text.Json.Nodes; +using LangChain.Providers.Amazon.Bedrock.Internal; + +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock; + +public abstract class AmazonTitanEmbeddingV2Model( + BedrockProvider provider, + string id) + : Model(id), IEmbeddingModel +{ + + /// + /// Creates embeddings for the input strings using the Amazon model. + /// + /// The `EmbeddingRequest` containing the input strings. + /// Optional `EmbeddingSettings` to override the model's default settings. + /// A cancellation token to cancel the operation. + /// An `EmbeddingResponse` containing the generated embeddings and usage information. + + public async Task CreateEmbeddingsAsync( + EmbeddingRequest request, + EmbeddingSettings? settings = null, + CancellationToken cancellationToken = default) + { + request = request ?? throw new ArgumentNullException(nameof(request)); + + var watch = Stopwatch.StartNew(); + + var usedSettings = AmazonV2EmbeddingSettings.Calculate( + requestSettings: settings, + modelSettings: Settings, + providerSettings: provider.EmbeddingSettings); + + var embeddings = new List(capacity: request.Strings.Count); + + var tasks = request.Strings.Select(text => + { + var bodyJson = CreateBodyJson(text, usedSettings); + return provider.Api.InvokeModelAsync(Id, bodyJson, + cancellationToken); + }) + .ToList(); + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + + foreach (var response in results) + { + var embedding = response?["embedding"]?.AsArray(); + if (embedding == null) continue; + + var f = new float[(int)usedSettings.Dimensions!]; + for (var i = 0; i < embedding.Count; i++) + { + f[i] = (float)embedding[(Index)i]?.AsValue()!; + } + + embeddings.Add(f); + } + + var usage = Usage.Empty with + { + Time = watch.Elapsed, + }; + AddUsage(usage); + provider.AddUsage(usage); + + return new EmbeddingResponse + { + Values = embeddings.ToArray(), + Usage = Usage.Empty, + UsedSettings = usedSettings, + Dimensions = embeddings.FirstOrDefault()?.Length ?? 0, + }; + } + + private static JsonObject CreateBodyJson(string? prompt, AmazonV2EmbeddingSettings usedSettings) + { + var bodyJson = new JsonObject + { + ["inputText"] = prompt, + ["dimensions"] = usedSettings.Dimensions, + ["normalize"] = usedSettings.Normalize + }; + + return bodyJson; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Embedding/Settings/AmazonV2EmbeddingSettings.cs b/src/Providers/Amazon.Bedrock/src/Embedding/Settings/AmazonV2EmbeddingSettings.cs new file mode 100644 index 00000000..bbb61c70 --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/Embedding/Settings/AmazonV2EmbeddingSettings.cs @@ -0,0 +1,55 @@ +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock; + +public class AmazonV2EmbeddingSettings : BedrockEmbeddingSettings +{ + public new static AmazonV2EmbeddingSettings Default { get; } = new() + { + Dimensions = 1024, + MaximumInputLength = 10_000, + Normalize = true + }; + + public bool Normalize { get; set; } + + /// + /// Calculate the settings to use for the request. + /// + /// + /// + /// + /// + /// + public new static AmazonV2EmbeddingSettings Calculate( + EmbeddingSettings? requestSettings, + EmbeddingSettings? modelSettings, + EmbeddingSettings? providerSettings) + { + var requestSettingsCasted = requestSettings as AmazonV2EmbeddingSettings; + var modelSettingsCasted = modelSettings as AmazonV2EmbeddingSettings; + var providerSettingsCasted = providerSettings as AmazonV2EmbeddingSettings; + + return new AmazonV2EmbeddingSettings + { + Dimensions = + requestSettingsCasted?.Dimensions ?? + modelSettingsCasted?.Dimensions ?? + providerSettingsCasted?.Dimensions ?? + Default.Dimensions ?? + throw new InvalidOperationException("Default Dimensions is not set."), + + MaximumInputLength = + requestSettingsCasted?.MaximumInputLength ?? + modelSettingsCasted?.MaximumInputLength ?? + providerSettingsCasted?.MaximumInputLength ?? + Default.MaximumInputLength ?? + throw new InvalidOperationException("Default MaximumInputLength is not set."), + + Normalize = + requestSettingsCasted?.Normalize ?? + modelSettingsCasted?.Normalize ?? + providerSettingsCasted?.Normalize ?? + Default.Normalize, + }; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Predefined/Amazon.cs b/src/Providers/Amazon.Bedrock/src/Predefined/Amazon.cs index bdbcac7a..3f27c93d 100644 --- a/src/Providers/Amazon.Bedrock/src/Predefined/Amazon.cs +++ b/src/Providers/Amazon.Bedrock/src/Predefined/Amazon.cs @@ -13,6 +13,10 @@ public class TitanTextLiteV1Model(BedrockProvider provider) public class TitanEmbedTextV1Model(BedrockProvider provider) : AmazonTitanEmbeddingModel(provider, id: "amazon.titan-embed-text-v1"); +/// +public class TitanEmbedTextV2Model(BedrockProvider provider) + : AmazonTitanEmbeddingV2Model(provider, id: "amazon.titan-embed-text-v2:0"); + /// public class TitanEmbedImageV1Model(BedrockProvider provider) : AmazonTitanImageEmbeddingModel(provider, id: "amazon.titan-embed-image-v1"); diff --git a/src/Providers/Amazon.Bedrock/src/Predefined/Cohere.cs b/src/Providers/Amazon.Bedrock/src/Predefined/Cohere.cs index 683a6e61..1c572814 100644 --- a/src/Providers/Amazon.Bedrock/src/Predefined/Cohere.cs +++ b/src/Providers/Amazon.Bedrock/src/Predefined/Cohere.cs @@ -16,3 +16,11 @@ public class EmbedEnglishV3Model(BedrockProvider provider) /// public class EmbedMultilingualV3Model(BedrockProvider provider) : CohereEmbeddingModel(provider, id: "cohere.embed-multilingual-v3"); + +/// +public class CommandRPlusModel(BedrockProvider provider) + : CohereCommandRModel(provider, id: "cohere.command-r-plus-v1:0"); + +/// +public class CommandRModel(BedrockProvider provider) + : CohereCommandRModel(provider, id: "cohere.command-r-v1:0"); diff --git a/src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs b/src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs index f551fdfa..05357e64 100644 --- a/src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs +++ b/src/Providers/Amazon.Bedrock/src/Predefined/Mistral.cs @@ -7,7 +7,7 @@ public class Mistral7BInstruct(BedrockProvider provider) /// public class Mistral8x7BInstruct(BedrockProvider provider) - : MistralModel(provider, id: "mistral.mistral-8x7b-instruct-v0:1"); + : MistralModel(provider, id: "mistral.mixtral-8x7b-instruct-v0:1"); /// public class MistralLarge(BedrockProvider provider) diff --git a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs index cb48f091..e77679be 100644 --- a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs +++ b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs @@ -13,11 +13,11 @@ using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic; using LangChain.Providers.Amazon.Bedrock.Predefined.Cohere; using LangChain.Providers.Amazon.Bedrock.Predefined.Meta; +using LangChain.Providers.Amazon.Bedrock.Predefined.Mistral; using LangChain.Providers.Amazon.Bedrock.Predefined.Stability; using LangChain.Schema; using LangChain.Splitters.Text; using static LangChain.Chains.Chain; -using Microsoft.SemanticKernel.AI.Embeddings; namespace LangChain.Providers.Amazon.Bedrock.Tests; @@ -31,7 +31,7 @@ public async Task Chains() //var llm = new Jurassic2MidModel(provider); //var llm = new ClaudeV21Model(provider); //var llm = new Mistral7BInstruct(provider); - var llm = new Claude3SonnetModel(provider); + var llm = new CommandRModel(provider); var template = "What is a good name for a company that makes {product}?"; var prompt = new PromptTemplate(new PromptTemplateInput(template, new List(1) { "product" })); @@ -271,7 +271,7 @@ public async Task ClaudeImageToText() public async Task SimpleTest(bool useStreaming, bool useChatSettings) { var provider = new BedrockProvider(); - var llm = new CommandLightTextV14Model(provider); + var llm = new CommandRModel(provider); llm.PromptSent += (_, prompt) => Console.WriteLine($"Prompt: {prompt}"); llm.PartialResponseGenerated += (_, delta) => Console.Write(delta); @@ -290,6 +290,7 @@ you are a comic book writer. you will be given a question and you will answer i { var response = await llm.GenerateAsync(prompt); response.LastMessageContent.Should().NotBeNull(); + Console.WriteLine(response.LastMessageContent); } } diff --git a/src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs b/src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs index 84d2797b..1fa1c198 100644 --- a/src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs +++ b/src/Providers/Amazon.Bedrock/test/BedrockTextModelTests.cs @@ -1,6 +1,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.CSharp; using System.Reflection; +using Amazon; namespace LangChain.Providers.Amazon.Bedrock.Tests; @@ -20,7 +21,7 @@ public Task TestAllTextLLMs() var derivedTypeNames = FindDerivedTypes(predefinedDir); - var provider = new BedrockProvider(); + var provider = new BedrockProvider(RegionEndpoint.USWest2); var failedTypes = new List(); var workingTypes = new Dictionary(); @@ -44,13 +45,20 @@ public Task TestAllTextLLMs() Console.WriteLine($"############## {type.FullName}"); - object[] args = { provider }; + object[] args = [provider]; var llm = (ChatModel)Activator.CreateInstance(type, args)!; - var result = llm.GenerateAsync("who's your favor superhero?"); + try + { + var result = llm.GenerateAsync("who's your favorite superhero?"); - workingTypes.Add(className, result.Result.Usage.Time.TotalSeconds); + workingTypes.Add(className, result.Result.Usage.Time.TotalSeconds); - Console.WriteLine(result.Result + "\n\n\n"); + Console.WriteLine(result.Result + "\n\n\n"); + } + catch (Exception e) + { + Console.WriteLine($"**** **** **** ERROR: " + e); + } } } catch (Exception e)