diff --git a/lib/langchain/llm/aws_bedrock.rb b/lib/langchain/llm/aws_bedrock.rb
index 3adf28051..182a466bf 100644
--- a/lib/langchain/llm/aws_bedrock.rb
+++ b/lib/langchain/llm/aws_bedrock.rb
@@ -32,6 +32,7 @@ class AwsBedrock < Base
     ].freeze
 
     SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
+      amazon
       anthropic
       ai21
       mistral
@@ -216,6 +217,8 @@ def compose_parameters(params, model_id)
         params
       elsif provider_name(model_id) == :mistral
         params
+      elsif provider_name(model_id) == :amazon
+        compose_parameters_amazon(params)
       end
     end
 
@@ -238,6 +241,8 @@ def parse_response(response, model_id)
         Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
       elsif provider_name(model_id) == :mistral
         Langchain::LLM::MistralAIResponse.new(JSON.parse(response.body.string))
+      elsif provider_name(model_id) == :amazon
+        Langchain::LLM::AwsBedrockAmazonResponse.new(JSON.parse(response.body.string))
       end
     end
 
@@ -288,6 +293,19 @@ def compose_parameters_anthropic(params)
       params.merge(anthropic_version: "bedrock-2023-05-31")
     end
 
+    # Amazon models expect sampling settings nested under inferenceConfig.
+    def compose_parameters_amazon(params)
+      params = params.merge(inferenceConfig: {
+        maxTokens: params[:max_tokens],
+        temperature: params[:temperature],
+        topP: params[:top_p],
+        topK: params[:top_k],
+        stopSequences: params[:stop_sequences]
+      }.compact)
+
+      params.reject { |k, _| %i[max_tokens temperature top_p top_k stop_sequences].include?(k) }
+    end
+
     def response_from_chunks(chunks)
       raw_response = {}
 
diff --git a/lib/langchain/llm/response/aws_bedrock_amazon_response.rb b/lib/langchain/llm/response/aws_bedrock_amazon_response.rb
new file mode 100644
index 000000000..c2c9d5769
--- /dev/null
+++ b/lib/langchain/llm/response/aws_bedrock_amazon_response.rb
@@ -0,0 +1,37 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class AwsBedrockAmazonResponse < BaseResponse
+    def completion
+      raw_response.dig("output", "message", "content", 0, "text")
+    end
+
+    def chat_completion
+      completion
+    end
+
+    def chat_completions
+      completions
+    end
+
+    def completions
+      nil
+    end
+
+    def stop_reason
+      raw_response.dig("stopReason")
+    end
+
+    def prompt_tokens
+      raw_response.dig("usage", "inputTokens").to_i
+    end
+
+    def completion_tokens
+      raw_response.dig("usage", "outputTokens").to_i
+    end
+
+    def total_tokens
+      raw_response.dig("usage", "totalTokens").to_i
+    end
+  end
+end
diff --git a/spec/langchain/llm/aws_bedrock_spec.rb b/spec/langchain/llm/aws_bedrock_spec.rb
index 268e241f9..b497e7eb0 100644
--- a/spec/langchain/llm/aws_bedrock_spec.rb
+++ b/spec/langchain/llm/aws_bedrock_spec.rb
@@ -119,6 +119,50 @@
           end
         end
       end
+
+      context "with amazon provider" do
+        let(:response) do
+          {
+            output: {
+              message: {
+                content: [
+                  {text: "The capital of France is Paris."}
+                ]
+              }
+            },
+            usage: {inputTokens: 14, outputTokens: 10}
+          }.to_json
+        end
+
+        let(:model_id) { "amazon.nova-pro-v1:0" }
+
+        before do
+          response_object = double("response_object")
+          allow(response_object).to receive(:body).and_return(StringIO.new(response))
+          allow(subject.client).to receive(:invoke_model)
+            .with(matching(
+              model_id:,
+              body: {
+                messages: [{role: "user", content: [{text: "What is the capital of France?"}]}],
+                inferenceConfig: {
+                  maxTokens: 300
+                }
+              }.to_json,
+              content_type: "application/json",
+              accept: "application/json"
+            ))
+            .and_return(response_object)
+        end
+
+        it "returns a completion" do
+          expect(
+            subject.chat(
+              messages: [{role: "user", content: [{text: "What is the capital of France?"}]}],
+              model: model_id
+            ).chat_completion
+          ).to eq("The capital of France is Paris.")
+        end
+      end
     end
   end
   describe "#complete" do
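Usage sketch (not part of the diff): a minimal example of how the new amazon provider path would be exercised, mirroring the spec added above. It assumes the aws-sdk-bedrockruntime gem is installed and that AWS credentials and region are available to the default Aws::BedrockRuntime::Client; the model id and expected values come from the spec fixture.

require "langchain"

# Default construction; credentials are picked up by the underlying
# Aws::BedrockRuntime::Client.
llm = Langchain::LLM::AwsBedrock.new

# Routed through compose_parameters_amazon and parsed by AwsBedrockAmazonResponse.
response = llm.chat(
  messages: [{role: "user", content: [{text: "What is the capital of France?"}]}],
  model: "amazon.nova-pro-v1:0"
)

response.chat_completion # => "The capital of France is Paris." (value stubbed in the spec above)
response.prompt_tokens   # => 14, per the spec's usage fixture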