Skip to content

Commit

Permalink
feat: Amazon Nova support via Bedrock (#1548)
Browse files Browse the repository at this point in the history
  • Loading branch information
parhammmm authored Dec 11, 2024
1 parent 086a651 commit c1850ee
Show file tree
Hide file tree
Showing 13 changed files with 995 additions and 188 deletions.
6 changes: 6 additions & 0 deletions .changeset/popular-scissors-switch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@llamaindex/community": patch
"docs": patch
---

feat: Amazon Nova support via Bedrock
6 changes: 6 additions & 0 deletions apps/docs/docs/modules/llms/available_llms/bedrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ META_LLAMA3_2_1B_INSTRUCT = "meta.llama3-2-1b-instruct-v1:0"; // only available
META_LLAMA3_2_3B_INSTRUCT = "meta.llama3-2-3b-instruct-v1:0"; // only available via inference endpoints (see below)
META_LLAMA3_2_11B_INSTRUCT = "meta.llama3-2-11b-instruct-v1:0"; // only available via inference endpoints (see below), multimodal and function call supported
META_LLAMA3_2_90B_INSTRUCT = "meta.llama3-2-90b-instruct-v1:0"; // only available via inference endpoints (see below), multimodal and function call supported
AMAZON_NOVA_PRO_1 = "amazon.nova-pro-v1:0";
AMAZON_NOVA_LITE_1 = "amazon.nova-lite-v1:0";
AMAZON_NOVA_MICRO_1 = "amazon.nova-micro-v1:0";
```

You can also use Bedrock's Inference endpoints by using the model names:
Expand All @@ -53,6 +56,9 @@ US_META_LLAMA_3_2_1B_INSTRUCT = "us.meta.llama3-2-1b-instruct-v1:0";
US_META_LLAMA_3_2_3B_INSTRUCT = "us.meta.llama3-2-3b-instruct-v1:0";
US_META_LLAMA_3_2_11B_INSTRUCT = "us.meta.llama3-2-11b-instruct-v1:0";
US_META_LLAMA_3_2_90B_INSTRUCT = "us.meta.llama3-2-90b-instruct-v1:0";
US_AMAZON_NOVA_PRO_1 = "us.amazon.nova-pro-v1:0";
US_AMAZON_NOVA_LITE_1 = "us.amazon.nova-lite-v1:0";
US_AMAZON_NOVA_MICRO_1 = "us.amazon.nova-micro-v1:0";

// EU
EU_ANTHROPIC_CLAUDE_3_HAIKU = "eu.anthropic.claude-3-haiku-20240307-v1:0";
Expand Down
1 change: 1 addition & 0 deletions packages/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
## Current Features:

- Bedrock support for Amazon Nova models Pro, Lite and Micro
- Bedrock support for the Anthropic Claude Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock) including the latest Sonnet 3.5 v2 and Haiku 3.5
- Bedrock support for the Meta LLama 2, 3, 3.1 and 3.2 Models [usage](https://ts.llamaindex.ai/modules/llms/available_llms/bedrock)
- Meta LLama3.1 405b and Llama3.2 tool call support
Expand Down
4 changes: 2 additions & 2 deletions packages/community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
"bunchee": "5.6.1"
},
"dependencies": {
"@aws-sdk/client-bedrock-agent-runtime": "^3.693.0",
"@aws-sdk/client-bedrock-runtime": "^3.693.0",
"@aws-sdk/client-bedrock-agent-runtime": "^3.706.0",
"@aws-sdk/client-bedrock-runtime": "^3.706.0",
"@llamaindex/core": "workspace:*",
"@llamaindex/env": "workspace:*"
}
Expand Down
133 changes: 133 additions & 0 deletions packages/community/src/llm/bedrock/amazon/provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import type {
ContentBlockDelta,
ConverseOutput,
ConverseRequest,
ConverseResponse,
ConverseStreamOutput,
InvokeModelCommandInput,
InvokeModelWithResponseStreamCommandInput,
ResponseStream,
} from "@aws-sdk/client-bedrock-runtime";
import type {
BaseTool,
ChatMessage,
LLMMetadata,
ToolCall,
ToolCallLLMMessageOptions,
} from "@llamaindex/core/llms";
import { toUtf8 } from "../utils";

import { Provider, type BedrockChatStreamResponse } from "../provider";
import {
mapBaseToolsToAmazonTools,
mapChatMessagesToAmazonMessages,
} from "./utils";

export class AmazonProvider extends Provider<ConverseStreamOutput> {
getResultFromResponse(response: Record<string, any>): ConverseResponse {
return JSON.parse(toUtf8(response.body));
}

getToolsFromResponse<ToolContent>(response: ConverseOutput): ToolContent[] {
return (
response.message?.content
?.filter((item) => item.toolUse)
.map(
(item) =>
({
id: item.toolUse!.toolUseId,
name: item.toolUse!.name,
input: item.toolUse!.input
? JSON.parse(item.toolUse!.input as string)
: "",
}) as ToolContent,
) ?? []
);
}

getTextFromResponse(response: ConverseResponse): string {
const result = this.getResultFromResponse(response);
const content = result.output?.message?.content ?? [];
return content.map((item) => item.text).join(" ");
}

getTextFromStreamResponse(response: ResponseStream): string {
let event: ConverseStreamOutput | undefined =
this.getStreamingEventResponse(response);
if (!event || !event.contentBlockDelta) return "";
const delta: ContentBlockDelta | undefined = event.contentBlockDelta.delta;
return delta?.text || "";
}

async *reduceStream(
stream: AsyncIterable<ResponseStream>,
): BedrockChatStreamResponse {
let toolId: string | undefined = undefined;
let toolName: string | undefined = undefined;
for await (const response of stream) {
const event = this.getStreamingEventResponse(response);
const delta = this.getTextFromStreamResponse(response);

let options: undefined | ToolCallLLMMessageOptions = undefined;
if (event?.contentBlockStart && event.contentBlockStart.start?.toolUse) {
toolId = event.contentBlockStart.start?.toolUse.toolUseId;
toolName = event.contentBlockStart.start?.toolUse.name;
continue;
}
if (
toolId &&
toolName &&
event?.contentBlockDelta?.delta?.toolUse?.input
) {
options = {
toolCall: [
{
id: toolId,
name: toolName,
input: JSON.parse(event?.contentBlockDelta?.delta?.toolUse.input),
} as ToolCall,
],
};
toolId = undefined;
toolName = undefined;
}

if (!delta && !options) continue;

yield {
delta: options ? "" : delta,
options,
raw: response,
};
}
}

getRequestBody<T extends ChatMessage>(
metadata: LLMMetadata,
messages: T[],
tools: BaseTool[] = [],
options: Omit<ConverseRequest, "modelId" | "messages" | "inferenceConfig">,
): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput {
const request: Omit<ConverseRequest, "modelId"> = {
...options,
messages: mapChatMessagesToAmazonMessages(messages),
inferenceConfig: {
maxTokens: metadata.maxTokens,
temperature: metadata.temperature,
topP: metadata.topP,
},
};
if (tools.length) {
request.toolConfig = {
tools: mapBaseToolsToAmazonTools(tools),
};
}

return {
modelId: metadata.model,
contentType: "application/json",
accept: "application/json",
body: JSON.stringify(request),
};
}
}
5 changes: 5 additions & 0 deletions packages/community/src/llm/bedrock/amazon/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import type { ConverseRequest, Message } from "@aws-sdk/client-bedrock-runtime";

export type AmazonMessages = ConverseRequest["messages"];

export type AmazonMessage = Message;
141 changes: 141 additions & 0 deletions packages/community/src/llm/bedrock/amazon/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import type {
ImageBlock,
ImageFormat,
Message,
Tool,
} from "@aws-sdk/client-bedrock-runtime";
import type {
BaseTool,
ChatMessage,
MessageContentDetail,
ToolCallLLMMessageOptions,
} from "@llamaindex/core/llms";
import {
extractDataUrlComponents,
mapMessageContentToMessageContentDetails,
} from "../utils";

import type { JSONObject } from "@llamaindex/core/global";
import type { AmazonMessage, AmazonMessages } from "./types";

const ACCEPTED_IMAGE_MIME_TYPES = [
"image/jpeg",
"image/png",
"image/webp",
"image/gif",
] as const;

const ACCEPTED_IMAGE_MIME_TYPE_FORMAT_MAP: Record<
(typeof ACCEPTED_IMAGE_MIME_TYPES)[number],
ImageFormat
> = {
"image/jpeg": "jpeg",
"image/png": "png",
"image/webp": "webp",
"image/gif": "gif",
};

export const mapImageContent = (imageUrl: string): ImageBlock => {
if (!imageUrl.startsWith("data:"))
throw new Error(
"For Amazon please only use base64 data url, e.g.: data:image/jpeg;base64,SGVsbG8sIFdvcmxkIQ==",
);
const { mimeType, base64: data } = extractDataUrlComponents(imageUrl);
if (
!ACCEPTED_IMAGE_MIME_TYPES.includes(
mimeType as keyof typeof ACCEPTED_IMAGE_MIME_TYPE_FORMAT_MAP,
)
)
throw new Error(
`Amazon only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`,
);

return {
format:
ACCEPTED_IMAGE_MIME_TYPE_FORMAT_MAP[
mimeType as keyof typeof ACCEPTED_IMAGE_MIME_TYPE_FORMAT_MAP
],

// @ts-ignore: there's a mistake in the "@aws-sdk/client-bedrock-runtime" compared to the actual api
source: { bytes: data },
};
};

export const mapMessageContentDetailToAmazonContent = <
T extends MessageContentDetail,
>(
detail: T,
): Message["content"] => {
let content: Message["content"] = [];

if (detail.type === "text") {
content = [{ text: detail.text }];
} else if (detail.type === "image_url") {
content = [{ image: mapImageContent(detail.image_url.url) }];
} else {
throw new Error("Unsupported content detail type");
}
return content;
};

export const mapChatMessagesToAmazonMessages = <
T extends ChatMessage<ToolCallLLMMessageOptions>,
>(
messages: T[],
): AmazonMessages => {
return messages.flatMap((msg: T): AmazonMessage[] => {
return mapMessageContentToMessageContentDetails(msg.content).map(
(detail: MessageContentDetail): AmazonMessage => {
if (msg.options && "toolCall" in msg.options) {
return {
role: "assistant",
content: msg.options.toolCall.map((call) => ({
toolUse: {
toolUseId: call.id,
name: call.name,
input: call.input as JSONObject,
},
})),
};
}
if (msg.options && "toolResult" in msg.options) {
return {
role: "user",
content: [
{
toolResult: {
toolUseId: msg.options.toolResult.id,
content: [
{
text: msg.options.toolResult.result,
},
],
},
},
],
};
}

return {
role: msg.role === "assistant" ? "assistant" : "user",
content: mapMessageContentDetailToAmazonContent(detail),
};
},
);
});
};

export const mapBaseToolsToAmazonTools = (tools?: BaseTool[]): Tool[] => {
if (!tools) return [];
return tools.map((tool: BaseTool) => {
const {
metadata: { parameters, ...options },
} = tool;
return {
toolSpec: {
...options,
inputSchema: parameters,
},
} as Tool;
});
};
9 changes: 3 additions & 6 deletions packages/community/src/llm/bedrock/anthropic/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@ import type {
ToolCall,
ToolCallLLMMessageOptions,
} from "@llamaindex/core/llms";
import {
type BedrockAdditionalChatOptions,
type BedrockChatStreamResponse,
Provider,
} from "../provider";
import { type BedrockChatStreamResponse, Provider } from "../provider";
import { toUtf8 } from "../utils";
import type {
AnthropicAdditionalChatOptions,
AnthropicNoneStreamingResponse,
AnthropicStreamEvent,
AnthropicTextContent,
Expand Down Expand Up @@ -134,7 +131,7 @@ export class AnthropicProvider extends Provider<AnthropicStreamEvent> {
metadata: LLMMetadata,
messages: T[],
tools?: BaseTool[],
options?: BedrockAdditionalChatOptions,
options?: AnthropicAdditionalChatOptions,
): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput {
const extra: Record<string, unknown> = {};
if (options?.toolChoice) {
Expand Down
7 changes: 7 additions & 0 deletions packages/community/src/llm/bedrock/anthropic/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import type { ToolMetadata } from "@llamaindex/core/llms";
import type { InvocationMetrics } from "../types";

export type ToolChoice =
| { type: "any" }
| { type: "auto" }
| { type: "tool"; name: string };

export type AnthropicAdditionalChatOptions = { toolChoice: ToolChoice };

type Usage = {
input_tokens: number;
output_tokens: number;
Expand Down
Loading

0 comments on commit c1850ee

Please sign in to comment.