6 changes: 6 additions & 0 deletions config/test.exs
@@ -279,6 +279,11 @@ config :llm_db,
family: "command",
capabilities: %{chat: true},
limits: %{context: 4096, output: 4096}
},
"cohere.embed-english-v3" => %{
name: "Cohere Embed English v3",
family: "embed",
capabilities: %{embeddings: true}
}
}
],
@@ -318,6 +323,7 @@ config :req_llm, :sample_embedding_models, ~w(
openai:text-embedding-3-small
google:text-embedding-004
azure:text-embedding-3-small
amazon_bedrock:cohere.embed-english-v3
)
config :req_llm, :sample_text_models, ~w(
anthropic:claude-3-5-haiku-20241022
197 changes: 179 additions & 18 deletions lib/req_llm/providers/amazon_bedrock.ex
@@ -177,6 +177,29 @@ defmodule ReqLLM.Providers.AmazonBedrock do
default: "default",
doc:
"Service tier for request prioritization. Priority provides faster responses at higher cost, Flex is more cost-effective with longer latency."
],
input_type: [
type: {:in, ["search_document", "search_query", "classification", "clustering"]},
default: "search_document",
doc: "Input type for Cohere embedding models"
],
embedding_types: [
type: {:list, {:in, ["float", "int8", "uint8", "binary", "ubinary"]}},
default: ["float"],
doc: "Output formats for Cohere embeddings"
],
truncate: [
type: {:in, ["NONE", "LEFT", "RIGHT"]},
default: "NONE",
doc: "Truncation strategy for Cohere embedding models"
],
images: [
type: {:list, :string},
doc: "List of base64-encoded images for Cohere image embeddings"
],
inputs: [
type: {:list, :map},
doc: "List of mixed content parts for Cohere interleaved embeddings"
]
]
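    # NOTE: these Cohere-specific keys appear to mirror the Bedrock cohere.embed invoke
    # payload (input_type, embedding_types, truncate, plus optional images/inputs); the
    # family formatter's format_embedding_request/3 (not shown in this diff) is presumably
    # responsible for copying them into the request body.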

@@ -188,6 +211,10 @@ defmodule ReqLLM.Providers.AmazonBedrock do
"meta" => ReqLLM.Providers.AmazonBedrock.Meta
}

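  # Model families with a dedicated embedding formatter module; embedding requests for
  # any other family are rejected in get_embedding_formatter/1.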
@embedding_families %{
"cohere" => ReqLLM.Providers.AmazonBedrock.Cohere
}

def default_base_url do
# Override to handle region template
"https://bedrock-runtime.{region}.amazonaws.com"
@@ -248,11 +275,35 @@ defmodule ReqLLM.Providers.AmazonBedrock do
end
end

@impl ReqLLM.Provider
def prepare_request(:embedding, model_input, text, opts) do
with {:ok, model} <- ReqLLM.model(model_input) do
http_opts = Keyword.get(opts, :req_http_options, [])
model_id = model.provider_model_id || model.id

timeout =
Keyword.get(
opts,
:receive_timeout,
Application.get_env(:req_llm, :receive_timeout, 30_000)
)

request =
Req.new(
[url: "/model/#{model_id}/invoke", method: :post, receive_timeout: timeout] ++
http_opts
)
|> attach_embedding(model, Keyword.put(opts, :text, text))

{:ok, request}
end
end

def prepare_request(operation, _model, _input, _opts) do
{:error,
InvalidParameter.exception(
parameter:
"operation: #{inspect(operation)} not supported by Bedrock provider. Supported operations: [:chat, :object]"
"operation: #{inspect(operation)} not supported by Bedrock provider. Supported operations: [:chat, :object, :embedding]"
)}
end

@@ -293,15 +344,7 @@ defmodule ReqLLM.Providers.AmazonBedrock do
operation
)

# Construct the base URL with region
region =
case aws_creds do
%{region: r} when is_binary(r) -> r
%{region: _} -> "us-east-1"
%AWSAuth.Credentials{region: r} when is_binary(r) -> r
%AWSAuth.Credentials{} -> "us-east-1"
_ -> "us-east-1"
end
region = extract_region(aws_creds)

base_url = "https://bedrock-runtime.#{region}.amazonaws.com"

@@ -399,6 +442,63 @@ defmodule ReqLLM.Providers.AmazonBedrock do
|> ReqLLM.Step.Fixture.maybe_attach(model, user_opts)
end

def attach_embedding(%Req.Request{} = request, model_input, user_opts) do
%LLMDB.Model{} =
model =
case ReqLLM.model(model_input) do
{:ok, m} -> m
{:error, err} -> raise err
end

if model.provider != provider_id() do
raise Error.Invalid.Provider.exception(provider: model.provider)
end

{aws_creds, other_opts} = extract_aws_credentials(user_opts)
validate_aws_credentials!(aws_creds)

processed_opts =
case ReqLLM.Provider.Options.process(__MODULE__, :embedding, model, other_opts) do
{:ok, opts} -> opts
{:error, error} -> raise error
end

region = extract_region(aws_creds)

base_url = "https://bedrock-runtime.#{region}.amazonaws.com"
model_id = model.provider_model_id || model.id
{model_family, formatter} = get_embedding_formatter(model_id)
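    # model_family is threaded through the request options below so that the response
    # step (decode_embedding_response/1) can look up the same formatter when parsing
    # the embedding payload.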

text = processed_opts[:text]

case formatter.format_embedding_request(model_id, text, processed_opts) do
{:ok, model_body} ->
updated_request =
request
|> Map.put(:url, URI.parse(base_url <> "/model/#{model_id}/invoke"))
|> Req.Request.register_options([:model, :text, :operation, :model_family])
|> Req.Request.merge_options(
base_url: base_url,
model: model_id,
operation: :embedding,
model_family: model_family
)
|> Req.Request.put_header("content-type", "application/json")
|> Req.Request.put_private(:req_llm_model, model)
|> Map.put(:body, Jason.encode!(model_body))

updated_request
|> Step.Error.attach()
|> ReqLLM.Step.Retry.attach()
|> put_aws_sigv4(aws_creds)
|> Req.Request.append_response_steps(llm_decode_embedding: &decode_embedding_response/1)
|> ReqLLM.Step.Fixture.maybe_attach(model, user_opts)

{:error, error} ->
raise error
end
end

@impl ReqLLM.Provider
def attach_stream(model, context, opts, _finch_name) do
# Get AWS credentials
@@ -713,6 +813,21 @@ defmodule ReqLLM.Providers.AmazonBedrock do

defp validate_aws_credentials!(%AWSAuth.Credentials{}), do: :ok

defp extract_region(aws_creds) do
case aws_creds do
%{region: r} when is_binary(r) -> r
%AWSAuth.Credentials{region: r} when is_binary(r) -> r
_ -> "us-east-1"
end
end

defp strip_region_prefix(model_id) do
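    # Drop a cross-region inference-profile prefix so family lookups match the bare
    # model id, e.g. "us.cohere.embed-english-v3" -> "cohere.embed-english-v3".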
case String.split(model_id, ".", parts: 2) do
[region, rest] when region in ["us", "eu", "ap", "ca", "global"] -> rest
_ -> model_id
end
end

# API Key authentication - use Bearer token
defp put_aws_sigv4(request, %{api_key: api_key}) when is_binary(api_key) do
Req.Request.put_header(request, "authorization", "Bearer #{api_key}")
@@ -800,14 +915,7 @@ defmodule ReqLLM.Providers.AmazonBedrock do
end

defp get_model_family(model_id) do
normalized_id =
case String.split(model_id, ".", parts: 2) do
[possible_region, rest] when possible_region in ["us", "eu", "ap", "ca", "global"] ->
rest

_ ->
model_id
end
normalized_id = strip_region_prefix(model_id)

found_family =
@model_families
@@ -830,6 +938,30 @@ defmodule ReqLLM.Providers.AmazonBedrock do
"""
end

defp get_embedding_formatter(model_id) do
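    # Resolve the family-specific embedding formatter (currently Cohere only) from the
    # model id prefix; unsupported families raise InvalidParameter.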
normalized_id = strip_region_prefix(model_id)

result =
@embedding_families
|> Enum.find(fn {prefix, _module} ->
String.starts_with?(normalized_id, prefix <> ".")
end)

case result do
{family, formatter} ->
{family, formatter}

nil ->
supported = Map.keys(@embedding_families) |> Enum.join(", ")

raise InvalidParameter.exception(
parameter:
"Embedding not supported for model: #{model_id}. " <>
"Supported embedding model families: #{supported}"
)
end
end

@impl ReqLLM.Provider
def translate_options(operation, model, opts) do
# Delegate to native Anthropic option translation for Anthropic models
@@ -953,6 +1085,35 @@ defmodule ReqLLM.Providers.AmazonBedrock do
{req, err}
end

defp decode_embedding_response({req, %{status: 200} = resp}) do
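    # Successful bodies are normalized by the family-specific formatter; fixture
    # replays pass through unchanged.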
if req.private[:llm_fixture_replay] do
{req, resp}
else
parsed_body = ensure_parsed_body(resp.body)
model_family = req.options[:model_family]
formatter = Map.get(@embedding_families, model_family)

case formatter.parse_embedding_response(parsed_body) do
{:ok, normalized_response} ->
{req, %{resp | body: normalized_response}}

{:error, error} ->
{req, error}
end
end
end

defp decode_embedding_response({req, resp}) do
err =
Error.API.Response.exception(
reason: "Bedrock embedding API error",
status: resp.status,
response_body: resp.body
)

{req, err}
end

@impl ReqLLM.Provider
def thinking_constraints do
# AWS Bedrock requires temperature=1.0 when extended thinking is enabled
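For reference, a minimal usage sketch of the new embedding path. This is illustrative only: it assumes ReqLLM.embed/3 is the public embeddings entry point, that AWS credentials are resolved from the usual Bedrock options or environment, and it uses the option names from the schema added in this diff.

# Hypothetical call; not part of this diff.
{:ok, embedding} =
  ReqLLM.embed(
    "amazon_bedrock:cohere.embed-english-v3",
    "What is retrieval-augmented generation?",
    input_type: "search_query",
    embedding_types: ["float"]
  )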