From e60802cc28f36135ded8b4978dbe920847a84ca5 Mon Sep 17 00:00:00 2001 From: Ty Augustine Date: Tue, 27 Feb 2024 21:59:32 -0500 Subject: [PATCH 1/4] feat: Added Bedrock streaming --- .../src/BedrockModelStreamRequest.cs | 26 ++++++ .../src/Chat/AmazonTitanChatModel.cs | 82 +++++++++++++----- .../src/Chat/AnthropicClaudeChatModel.cs | 85 ++++++++++++++----- .../src/Chat/BedrockChatSettings.cs | 9 +- .../src/Chat/CohereCommandChatModel.cs | 66 +++++++++++--- .../src/Chat/MetaLlama2ChatModel.cs | 67 +++++++++++---- .../Amazon.Bedrock/test/BedrockTests.cs | 46 ++++++++-- 7 files changed, 303 insertions(+), 78 deletions(-) create mode 100644 src/Providers/Amazon.Bedrock/src/BedrockModelStreamRequest.cs diff --git a/src/Providers/Amazon.Bedrock/src/BedrockModelStreamRequest.cs b/src/Providers/Amazon.Bedrock/src/BedrockModelStreamRequest.cs new file mode 100644 index 00000000..0413ebaa --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/BedrockModelStreamRequest.cs @@ -0,0 +1,26 @@ +using System.Text; +using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; + +namespace LangChain.Providers.Amazon.Bedrock; + +internal record BedrockModelStreamRequest +{ + public static InvokeModelWithResponseStreamRequest Create(string modelId, JsonObject bodyJson) + { + bodyJson = bodyJson ?? throw new ArgumentNullException(nameof(bodyJson)); + + var byteArray = Encoding.UTF8.GetBytes(bodyJson.ToJsonString()); + var stream = new MemoryStream(byteArray); + + var bedrockRequest = new InvokeModelWithResponseStreamRequest + { + ModelId = modelId, + ContentType = "application/json", + Accept = "application/json", + Body = stream + }; + + return bedrockRequest; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Chat/AmazonTitanChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/AmazonTitanChatModel.cs index 17921300..fea51c27 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/AmazonTitanChatModel.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/AmazonTitanChatModel.cs @@ -1,6 +1,8 @@ using System.Diagnostics; -using System.Linq; +using System.Text; +using System.Text.Json; using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; using LangChain.Providers.Amazon.Bedrock.Internal; // ReSharper disable once CheckNamespace @@ -11,8 +13,6 @@ public abstract class AmazonTitanChatModel( string id) : ChatModel(id) { - public override int ContextLength => 4096; - public override async Task GenerateAsync( ChatRequest request, ChatSettings? settings = null, @@ -22,32 +22,60 @@ public override async Task GenerateAsync( var watch = Stopwatch.StartNew(); var prompt = request.Messages.ToSimplePrompt(); + var messages = request.Messages.ToList(); + + var stringBuilder = new StringBuilder(); var usedSettings = BedrockChatSettings.Calculate( requestSettings: settings, modelSettings: Settings, providerSettings: provider.ChatSettings); - var response = await provider.Api.InvokeModelAsync( - Id, - new JsonObject + + var bodyJson = CreateBodyJson(prompt, usedSettings); + + if (usedSettings.UseStreaming == true) + { + var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson); + var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken); + + foreach (var payloadPart in response.Body) { - ["inputText"] = prompt, - ["textGenerationConfig"] = new JsonObject + var streamEvent = (PayloadPart)payloadPart; + var chunk = await JsonSerializer.DeserializeAsync(streamEvent.Bytes, cancellationToken: cancellationToken) + .ConfigureAwait(false); + var delta = chunk?["outputText"]!.GetValue(); + + OnPartialResponseGenerated(delta!); + stringBuilder.Append(delta); + + var finished = chunk?["completionReason"]?.GetValue(); + if (finished?.ToLower() == "finish") { - ["maxTokenCount"] = usedSettings.MaxTokens!.Value, - ["temperature"] = usedSettings.Temperature!.Value, - ["topP"] = usedSettings.TopP!.Value, - ["stopSequences"] = usedSettings.StopSequences!.AsArray() + OnCompletedResponseGenerated(stringBuilder.ToString()); } - }, - cancellationToken).ConfigureAwait(false); + } + + OnPartialResponseGenerated(Environment.NewLine); + stringBuilder.Append(Environment.NewLine); - var generatedText = response?["results"]?[0]?["outputText"]?.GetValue() ?? string.Empty; + var newMessage = new Message( + Content: stringBuilder.ToString(), + Role: MessageRole.Ai); + messages.Add(newMessage); - var result = request.Messages.ToList(); - result.Add(generatedText.AsAiMessage()); + OnCompletedResponseGenerated(newMessage.Content); + } + else + { + var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) + .ConfigureAwait(false); + + var generatedText = response?["results"]?[0]?["outputText"]?.GetValue() ?? string.Empty; + + messages.Add(generatedText.AsAiMessage()); + OnCompletedResponseGenerated(generatedText); + } - // Unsupported var usage = Usage.Empty with { Time = watch.Elapsed, @@ -57,9 +85,25 @@ public override async Task GenerateAsync( return new ChatResponse { - Messages = result, + Messages = messages, UsedSettings = usedSettings, Usage = usage, }; } + + private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings) + { + var bodyJson = new JsonObject + { + ["inputText"] = prompt, + ["textGenerationConfig"] = new JsonObject + { + ["maxTokenCount"] = usedSettings.MaxTokens!.Value, + ["temperature"] = usedSettings.Temperature!.Value, + ["topP"] = usedSettings.TopP!.Value, + ["stopSequences"] = usedSettings.StopSequences!.AsArray() + } + }; + return bodyJson; + } } \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs index f68b070a..9d769de9 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/AnthropicClaudeChatModel.cs @@ -1,5 +1,8 @@ using System.Diagnostics; +using System.Text; +using System.Text.Json; using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; using LangChain.Providers.Amazon.Bedrock.Internal; // ReSharper disable once CheckNamespace @@ -19,31 +22,59 @@ public override async Task GenerateAsync( var watch = Stopwatch.StartNew(); var prompt = request.Messages.ToRolePrompt(); + var messages = request.Messages.ToList(); + + var stringBuilder = new StringBuilder(); var usedSettings = BedrockChatSettings.Calculate( requestSettings: settings, modelSettings: Settings, providerSettings: provider.ChatSettings); - var response = await provider.Api.InvokeModelAsync( - Id, - new JsonObject + + var bodyJson = CreateBodyJson(prompt, usedSettings); + + if (usedSettings.UseStreaming == true) + { + var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson); + var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken); + + foreach (var payloadPart in response.Body) { - ["prompt"] = prompt, - ["max_tokens_to_sample"] = usedSettings.MaxTokens!.Value, - ["temperature"] = usedSettings.Temperature!.Value, - ["top_p"] = usedSettings.TopP!.Value, - ["top_k"] = usedSettings.TopK!.Value, - ["stop_sequences"] = new JsonArray("\n\nHuman:") - }, - cancellationToken).ConfigureAwait(false); - - var generatedText = response?["completion"]? - .GetValue() ?? ""; - - var result = request.Messages.ToList(); - result.Add(generatedText.AsAiMessage()); - - // Unsupported + var streamEvent = (PayloadPart)payloadPart; + var chunk = await JsonSerializer.DeserializeAsync(streamEvent.Bytes, cancellationToken: cancellationToken) + .ConfigureAwait(false); + var delta = chunk?["completion"]!.GetValue(); + + OnPartialResponseGenerated(delta!); + stringBuilder.Append(delta); + + var finished = chunk?["completionReason"]?.GetValue(); + if (finished?.ToLower() == "finish") + { + OnCompletedResponseGenerated(stringBuilder.ToString()); + } + } + + OnPartialResponseGenerated(Environment.NewLine); + stringBuilder.Append(Environment.NewLine); + + var newMessage = new Message( + Content: stringBuilder.ToString(), + Role: MessageRole.Ai); + messages.Add(newMessage); + + OnCompletedResponseGenerated(newMessage.Content); + } + else + { + var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken).ConfigureAwait(false); + + var generatedText = response?["completion"]?.GetValue() ?? ""; + + messages.Add(generatedText.AsAiMessage()); + OnCompletedResponseGenerated(generatedText); + } + var usage = Usage.Empty with { Time = watch.Elapsed, @@ -53,9 +84,23 @@ public override async Task GenerateAsync( return new ChatResponse { - Messages = result, + Messages = messages, UsedSettings = usedSettings, Usage = usage, }; } + + private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings) + { + var bodyJson = new JsonObject + { + ["prompt"] = prompt, + ["max_tokens_to_sample"] = usedSettings.MaxTokens!.Value, + ["temperature"] = usedSettings.Temperature!.Value, + ["top_p"] = usedSettings.TopP!.Value, + ["top_k"] = usedSettings.TopK!.Value, + ["stop_sequences"] = new JsonArray("\n\nHuman:") + }; + return bodyJson; + } } \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Chat/BedrockChatSettings.cs b/src/Providers/Amazon.Bedrock/src/Chat/BedrockChatSettings.cs index e0881803..6df30c80 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/BedrockChatSettings.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/BedrockChatSettings.cs @@ -7,8 +7,9 @@ public class BedrockChatSettings : ChatSettings { StopSequences = ChatSettings.Default.StopSequences, User = ChatSettings.Default.User, + UseStreaming = false, Temperature = 0.7, - MaxTokens = 4096, + MaxTokens = 2048, TopP = 0.9, TopK = 0.0 }; @@ -66,6 +67,12 @@ public class BedrockChatSettings : ChatSettings providerSettingsCasted?.User ?? Default.User ?? throw new InvalidOperationException("Default User is not set."), + UseStreaming = + requestSettings?.UseStreaming ?? + modelSettings?.UseStreaming ?? + providerSettings?.UseStreaming ?? + Default.UseStreaming ?? + throw new InvalidOperationException("Default UseStreaming is not set."), Temperature = requestSettingsCasted?.Temperature ?? modelSettingsCasted?.Temperature ?? diff --git a/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs index 1a5e4f68..c502b9f6 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs @@ -1,5 +1,8 @@ using System.Diagnostics; +using System.Text; +using System.Text.Json; using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; using LangChain.Providers.Amazon.Bedrock.Internal; // ReSharper disable once CheckNamespace @@ -19,27 +22,49 @@ public override async Task GenerateAsync( var watch = Stopwatch.StartNew(); var prompt = request.Messages.ToSimplePrompt(); + var messages = request.Messages.ToList(); + + var stringBuilder = new StringBuilder(); var usedSettings = BedrockChatSettings.Calculate( requestSettings: settings, modelSettings: Settings, providerSettings: provider.ChatSettings); - var response = await provider.Api.InvokeModelAsync( - Id, - new JsonObject + + var bodyJson = CreateBodyJson(prompt, usedSettings); + + if (usedSettings.UseStreaming == true) + { + var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson); + var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken); + + foreach (var payloadPart in response.Body) { - ["prompt"] = prompt, - ["max_tokens"] = usedSettings.MaxTokens!.Value, - ["temperature"] = usedSettings.Temperature!.Value, - ["p"] = usedSettings.TopP!.Value, - ["k"] = usedSettings.TopK!.Value, - }, - cancellationToken).ConfigureAwait(false); + var streamEvent = (PayloadPart)payloadPart; + var chunk = await JsonSerializer.DeserializeAsync(streamEvent.Bytes, cancellationToken: cancellationToken) + .ConfigureAwait(false); + var delta = chunk?["generations"]?[0]?["text"]?.GetValue() ?? string.Empty; + + OnPartialResponseGenerated(delta!); + stringBuilder.Append(delta); - var generatedText = response?["generations"]?[0]?["text"]?.GetValue() ?? string.Empty; + var finished = chunk?["finish_reason"]?[0]?["text"]?.GetValue() ?? string.Empty; + if (string.Equals(finished?.ToLower(), "finish", StringComparison.Ordinal)) + { + OnCompletedResponseGenerated(stringBuilder.ToString()); + } + } + } + else + { + var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) + .ConfigureAwait(false); - var result = request.Messages.ToList(); - result.Add(generatedText.AsAiMessage()); + var generatedText = response?["generations"]?[0]?["text"]?.GetValue() ?? string.Empty; + + messages.Add(generatedText.AsAiMessage()); + OnCompletedResponseGenerated(generatedText); + } var usage = Usage.Empty with { @@ -50,9 +75,22 @@ public override async Task GenerateAsync( return new ChatResponse { - Messages = result, + Messages = messages, UsedSettings = usedSettings, Usage = usage, }; } + + private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings) + { + var bodyJson = new JsonObject + { + ["prompt"] = prompt, + ["max_tokens"] = usedSettings.MaxTokens!.Value, + ["temperature"] = usedSettings.Temperature!.Value, + ["p"] = usedSettings.TopP!.Value, + ["k"] = usedSettings.TopK!.Value, + }; + return bodyJson; + } } \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs index a5fe1594..9ad45231 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs @@ -1,5 +1,8 @@ using System.Diagnostics; +using System.Text; +using System.Text.Json; using System.Text.Json.Nodes; +using Amazon.BedrockRuntime.Model; using LangChain.Providers.Amazon.Bedrock.Internal; // ReSharper disable once CheckNamespace @@ -10,8 +13,6 @@ public class MetaLlama2ChatModel( string id) : ChatModel(id) { - public override int ContextLength => 4096; - public override async Task GenerateAsync( ChatRequest request, ChatSettings? settings = null, @@ -21,27 +22,50 @@ public override async Task GenerateAsync( var watch = Stopwatch.StartNew(); var prompt = request.Messages.ToSimplePrompt(); + var messages = request.Messages.ToList(); + + var stringBuilder = new StringBuilder(); var usedSettings = BedrockChatSettings.Calculate( requestSettings: settings, modelSettings: Settings, providerSettings: provider.ChatSettings); - var response = await provider.Api.InvokeModelAsync( - Id, - new JsonObject + + var bodyJson = CreateBodyJson(prompt, usedSettings); + + if (usedSettings.UseStreaming == true) + { + var streamRequest = BedrockModelStreamRequest.Create(Id, bodyJson); + var response = await provider.Api.InvokeModelWithResponseStreamAsync(streamRequest, cancellationToken); + + foreach (var payloadPart in response.Body) { - ["prompt"] = prompt, - ["max_gen_len"] = usedSettings.MaxTokens!.Value, - ["temperature"] = usedSettings.Temperature!.Value, - ["topP"] = usedSettings.TopP!.Value, - }, - cancellationToken).ConfigureAwait(false); + var streamEvent = (PayloadPart)payloadPart; + var chunk = await JsonSerializer.DeserializeAsync(streamEvent.Bytes, cancellationToken: cancellationToken) + .ConfigureAwait(false); + var delta = chunk?["generation"]?.GetValue() ?? string.Empty; + + OnPartialResponseGenerated(delta!); + stringBuilder.Append(delta); - var generatedText = response?["generation"]? - .GetValue() ?? string.Empty; + var finished = chunk?["stop_reason"]?.GetValue() ?? string.Empty; + if (string.Equals(finished?.ToLower(), "stop", StringComparison.Ordinal)) + { + OnCompletedResponseGenerated(stringBuilder.ToString()); + } + } + } + else + { + var response = await provider.Api.InvokeModelAsync(Id, bodyJson, cancellationToken) + .ConfigureAwait(false); + + var generatedText = response?["generation"]? + .GetValue() ?? string.Empty; - var result = request.Messages.ToList(); - result.Add(generatedText.AsAiMessage()); + messages.Add(generatedText.AsAiMessage()); + OnCompletedResponseGenerated(generatedText); + } var usage = Usage.Empty with { @@ -52,9 +76,20 @@ public override async Task GenerateAsync( return new ChatResponse { - Messages = result, + Messages = messages, UsedSettings = usedSettings, Usage = usage, }; } + + private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings usedSettings) + { + var bodyJson = new JsonObject + { + ["prompt"] = prompt, + ["max_gen_len"] = usedSettings.MaxTokens!.Value, + ["temperature"] = usedSettings.Temperature!.Value, + }; + return bodyJson; + } } \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs index f64d7dca..7357ad63 100644 --- a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs +++ b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs @@ -9,6 +9,7 @@ using LangChain.Providers.Amazon.Bedrock.Predefined.Ai21Labs; using LangChain.Providers.Amazon.Bedrock.Predefined.Amazon; using LangChain.Providers.Amazon.Bedrock.Predefined.Anthropic; +using LangChain.Providers.Amazon.Bedrock.Predefined.Cohere; using LangChain.Providers.Amazon.Bedrock.Predefined.Meta; using LangChain.Providers.Amazon.Bedrock.Predefined.Stability; using LangChain.Schema; @@ -199,9 +200,9 @@ public async Task SimpleRag() var chain = Set("what color is the car?", outputKey: "question") // set the question - //Set("Hagrid was looking for the golden key. Where was it?", outputKey: "question") // set the question - // Set("Who was on the Dursleys front step?", outputKey: "question") // set the question - // Set("Who was drinking a unicorn blood?", outputKey: "question") // set the question + //Set("Hagrid was looking for the golden key. Where was it?", outputKey: "question") // set the question + // Set("Who was on the Dursleys front step?", outputKey: "question") // set the question + // Set("Who was drinking a unicorn blood?", outputKey: "question") // set the question | RetrieveDocuments(index, inputKey: "question", outputKey: "documents", amount: 5) // take 5 most similar documents | StuffDocuments(inputKey: "documents", outputKey: "context") // combine documents together and put them into context | Template(promptText) // replace context and question in the prompt with their values @@ -220,12 +221,12 @@ public async Task CanGetImage() "create a picture of the solar system"); var path = Path.Combine(Path.GetTempPath(), "solar_system.png"); - + await File.WriteAllBytesAsync(path, response.Bytes); - + Process.Start(path); } - + [Test] public async Task CanGetImage2() { @@ -235,9 +236,38 @@ public async Task CanGetImage2() "i'm going to prepare a recipe. show me an image of realistic food ingredients"); var path = Path.Combine(Path.GetTempPath(), "food.png"); - + await File.WriteAllBytesAsync(path, response.Bytes); - + Process.Start(path); } + + [TestCase(true, false)] + [TestCase(false, false)] + [TestCase(true, true)] + [TestCase(false, true)] + public async Task SimpleTest(bool useStreaming, bool useChatSettings) + { + var provider = new BedrockProvider(); + var llm = new Llama2With13BModel(provider); + + llm.PromptSent += (_, prompt) => Console.WriteLine($"Prompt: {prompt}"); + llm.PartialResponseGenerated += (_, delta) => Console.Write(delta); + llm.CompletedResponseGenerated += (_, prompt) => Console.WriteLine($"Completed response: {prompt}"); + + var prompt = @" +you are a comic book writer. you will be given a question and you will answer it. +question: who are 10 of the most popular superheros and what are their powers?"; + + if (useChatSettings) + { + var response = await llm.GenerateAsync(prompt, new BedrockChatSettings { UseStreaming = useStreaming}); + response.LastMessageContent.Should().NotBeNull(); + } + else + { + var response = await llm.GenerateAsync(prompt); + response.LastMessageContent.Should().NotBeNull(); + } + } } \ No newline at end of file From 80d558ab865b3fff87a1a1bb7d1d38a18d1a049c Mon Sep 17 00:00:00 2001 From: Ty Augustine Date: Tue, 27 Feb 2024 22:27:58 -0500 Subject: [PATCH 2/4] fix: Cohere streaming stop_reason --- .../Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs | 4 ++-- src/Providers/Amazon.Bedrock/test/BedrockTests.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs index c502b9f6..3c088aba 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/CohereCommandChatModel.cs @@ -48,8 +48,8 @@ public override async Task GenerateAsync( OnPartialResponseGenerated(delta!); stringBuilder.Append(delta); - var finished = chunk?["finish_reason"]?[0]?["text"]?.GetValue() ?? string.Empty; - if (string.Equals(finished?.ToLower(), "finish", StringComparison.Ordinal)) + var finished = chunk?["generations"]?[0]?["finish_reason"]?.GetValue() ?? string.Empty; + if (string.Equals(finished?.ToLower(), "complete", StringComparison.Ordinal)) { OnCompletedResponseGenerated(stringBuilder.ToString()); } diff --git a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs index 7357ad63..ac35c9b2 100644 --- a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs +++ b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs @@ -249,7 +249,7 @@ public async Task CanGetImage2() public async Task SimpleTest(bool useStreaming, bool useChatSettings) { var provider = new BedrockProvider(); - var llm = new Llama2With13BModel(provider); + var llm = new CommandLightTextV14Model(provider); llm.PromptSent += (_, prompt) => Console.WriteLine($"Prompt: {prompt}"); llm.PartialResponseGenerated += (_, delta) => Console.Write(delta); From 6126572643da562a9fc979b30d17a3e0dad96b0e Mon Sep 17 00:00:00 2001 From: Ty Augustine Date: Wed, 28 Feb 2024 07:56:36 -0500 Subject: [PATCH 3/4] fix: Amazon Titan Image Generator response added: default Bedrock Image Settings --- .../AmazonTitanImageGenerationModel.cs | 32 ++++---- .../ImageGeneration/BedrockImageSettings.cs | 78 +++++++++++++++++++ .../StableDiffusionImageGenerationModel.cs | 2 - .../Amazon.Bedrock/test/BedrockTests.cs | 2 +- 4 files changed, 97 insertions(+), 17 deletions(-) create mode 100644 src/Providers/Amazon.Bedrock/src/ImageGeneration/BedrockImageSettings.cs diff --git a/src/Providers/Amazon.Bedrock/src/ImageGeneration/AmazonTitanImageGenerationModel.cs b/src/Providers/Amazon.Bedrock/src/ImageGeneration/AmazonTitanImageGenerationModel.cs index e0bb99e5..40ea1bd6 100644 --- a/src/Providers/Amazon.Bedrock/src/ImageGeneration/AmazonTitanImageGenerationModel.cs +++ b/src/Providers/Amazon.Bedrock/src/ImageGeneration/AmazonTitanImageGenerationModel.cs @@ -16,8 +16,13 @@ public async Task GenerateImageAsync( CancellationToken cancellationToken = default) { request = request ?? throw new ArgumentNullException(nameof(request)); - + var watch = Stopwatch.StartNew(); + + var usedSettings = BedrockImageSettings.Calculate( + requestSettings: settings, + modelSettings: Settings, + providerSettings: provider.ImageGenerationSettings); var response = await provider.Api.InvokeModelAsync( Id, new JsonObject @@ -25,23 +30,22 @@ public async Task GenerateImageAsync( ["taskType"] = "TEXT_IMAGE", ["textToImageParams"] = new JsonObject { - ["text"] = request.Prompt, - ["imageGenerationConfig"] = new JsonObject - { - ["quality"] = "standard", - ["width"] = 1024, - ["height"] = 1024, - ["cfgScale"] = 8.0, - ["seed"] = 0, - ["numberOfImages"] = 3, - } + ["text"] = request.Prompt + }, + ["imageGenerationConfig"] = new JsonObject + { + ["quality"] = "standard", + ["width"] = usedSettings.Width!.Value, + ["height"] = usedSettings.Height!.Value, + ["cfgScale"] = 8.0, + ["seed"] = usedSettings.Seed!.Value, + ["numberOfImages"] = usedSettings.NumOfImages!.Value, } }, cancellationToken).ConfigureAwait(false); - var generatedText = response?["results"]?[0]?["outputText"]?.GetValue() ?? ""; - - // Unsupported + var generatedText = response?["images"]?[0]?.GetValue() ?? ""; + var usage = Usage.Empty with { Time = watch.Elapsed, diff --git a/src/Providers/Amazon.Bedrock/src/ImageGeneration/BedrockImageSettings.cs b/src/Providers/Amazon.Bedrock/src/ImageGeneration/BedrockImageSettings.cs new file mode 100644 index 00000000..cde5a7dc --- /dev/null +++ b/src/Providers/Amazon.Bedrock/src/ImageGeneration/BedrockImageSettings.cs @@ -0,0 +1,78 @@ +// ReSharper disable once CheckNamespace +namespace LangChain.Providers.Amazon.Bedrock; + +public class BedrockImageSettings : ImageGenerationSettings +{ + public new static BedrockImageSettings Default { get; } = new() + { + Height = 1024, + Width = 1024, + Seed = 0, + NumOfImages = 1, + }; + + /// + /// + /// + public int? Height { get; init; } + /// + /// + /// + public int? Width { get; init; } + + /// + /// + /// + public int? Seed { get; init; } + + /// + /// + /// + public int? NumOfImages { get; init; } + + /// + /// Calculate the settings to use for the request. + /// + /// + /// + /// + /// + /// + public new static BedrockImageSettings Calculate( + ImageGenerationSettings? requestSettings, + ImageGenerationSettings? modelSettings, + ImageGenerationSettings? providerSettings) + { + var requestSettingsCasted = requestSettings as BedrockImageSettings; + var modelSettingsCasted = modelSettings as BedrockImageSettings; + var providerSettingsCasted = providerSettings as BedrockImageSettings; + + return new BedrockImageSettings + { + Height = + requestSettingsCasted?.Height ?? + modelSettingsCasted?.Height ?? + providerSettingsCasted?.Height ?? + Default.Height ?? + throw new InvalidOperationException("Default Height is not set."), + Width = + requestSettingsCasted?.Width ?? + modelSettingsCasted?.Width ?? + providerSettingsCasted?.Width ?? + Default.Width ?? + throw new InvalidOperationException("Default Width is not set."), + Seed = + requestSettingsCasted?.Seed ?? + modelSettingsCasted?.Seed ?? + providerSettingsCasted?.Seed ?? + Default.Seed ?? + throw new InvalidOperationException("Default Seed is not set."), + NumOfImages = + requestSettingsCasted?.NumOfImages ?? + modelSettingsCasted?.NumOfImages ?? + providerSettingsCasted?.NumOfImages ?? + Default.NumOfImages ?? + throw new InvalidOperationException("Default NumOfImages is not set."), + }; + } +} \ No newline at end of file diff --git a/src/Providers/Amazon.Bedrock/src/ImageGeneration/StableDiffusionImageGenerationModel.cs b/src/Providers/Amazon.Bedrock/src/ImageGeneration/StableDiffusionImageGenerationModel.cs index 7422d1cb..545224c2 100644 --- a/src/Providers/Amazon.Bedrock/src/ImageGeneration/StableDiffusionImageGenerationModel.cs +++ b/src/Providers/Amazon.Bedrock/src/ImageGeneration/StableDiffusionImageGenerationModel.cs @@ -33,9 +33,7 @@ public async Task GenerateImageAsync( var base64 = response?["artifacts"]?[0]?["base64"]? .GetValue() ?? string.Empty; - //var generatedText = $"data:image/jpeg;base64,{body}"; - // Unsupported var usage = Usage.Empty with { Time = watch.Elapsed, diff --git a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs index ac35c9b2..8e741531 100644 --- a/src/Providers/Amazon.Bedrock/test/BedrockTests.cs +++ b/src/Providers/Amazon.Bedrock/test/BedrockTests.cs @@ -216,7 +216,7 @@ public async Task SimpleRag() public async Task CanGetImage() { var provider = new BedrockProvider(); - var model = new StableDiffusionExtraLargeV0Model(provider); + var model = new TitanImageGeneratorV1Model(provider); var response = await model.GenerateImageAsync( "create a picture of the solar system"); From 2e3c51358e28b3fcd68f6b91d145787c2153b605 Mon Sep 17 00:00:00 2001 From: Ty Augustine Date: Wed, 28 Feb 2024 09:16:58 -0500 Subject: [PATCH 4/4] fix: re-added top_p to LLama2 model request --- src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs b/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs index 9ad45231..94ade436 100644 --- a/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs +++ b/src/Providers/Amazon.Bedrock/src/Chat/MetaLlama2ChatModel.cs @@ -89,6 +89,7 @@ private static JsonObject CreateBodyJson(string prompt, BedrockChatSettings used ["prompt"] = prompt, ["max_gen_len"] = usedSettings.MaxTokens!.Value, ["temperature"] = usedSettings.Temperature!.Value, + ["top_p"] = usedSettings.TopP!.Value, }; return bodyJson; }