From 0895d596f3b70b85b7c1b165df08ff644d9e7fba Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Fri, 9 Aug 2024 11:39:01 +0200 Subject: [PATCH 1/3] wip --- .../amazon-bedrock-anthropic.ts | 34 ++ .../generate-text/amazon-bedrock-anthropic.ts | 21 + .../anthropic/anthropic-messages-prompt.ts | 46 ++ ...drock-anthropic-messages-language-model.ts | 510 ++++++++++++++++++ .../convert-to-anthropic-messages-prompt.ts | 219 ++++++++ .../anthropic/map-anthropic-stop-reason.ts | 17 + .../amazon-bedrock/src/bedrock-provider.ts | 10 + 7 files changed, 857 insertions(+) create mode 100644 examples/ai-core/src/generate-object/amazon-bedrock-anthropic.ts create mode 100644 examples/ai-core/src/generate-text/amazon-bedrock-anthropic.ts create mode 100644 packages/amazon-bedrock/src/anthropic/anthropic-messages-prompt.ts create mode 100644 packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts create mode 100644 packages/amazon-bedrock/src/anthropic/convert-to-anthropic-messages-prompt.ts create mode 100644 packages/amazon-bedrock/src/anthropic/map-anthropic-stop-reason.ts diff --git a/examples/ai-core/src/generate-object/amazon-bedrock-anthropic.ts b/examples/ai-core/src/generate-object/amazon-bedrock-anthropic.ts new file mode 100644 index 000000000000..b143c67d1109 --- /dev/null +++ b/examples/ai-core/src/generate-object/amazon-bedrock-anthropic.ts @@ -0,0 +1,34 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { generateObject } from 'ai'; +import dotenv from 'dotenv'; +import { z } from 'zod'; + +dotenv.config(); + +async function main() { + const result = await generateObject({ + model: bedrock.anthropicLanguageModel( + 'anthropic.claude-3-5-sonnet-20240620-v1:0', + ), + schema: z.object({ + recipe: z.object({ + name: z.string(), + ingredients: z.array( + z.object({ + name: z.string(), + amount: z.string(), + }), + ), + steps: z.array(z.string()), + }), + }), + prompt: 'Generate a lasagna recipe.', + }); + + 
console.log(JSON.stringify(result.object, null, 2)); + console.log(); + console.log('Token usage:', result.usage); + console.log('Finish reason:', result.finishReason); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/generate-text/amazon-bedrock-anthropic.ts b/examples/ai-core/src/generate-text/amazon-bedrock-anthropic.ts new file mode 100644 index 000000000000..a10fb7d0e436 --- /dev/null +++ b/examples/ai-core/src/generate-text/amazon-bedrock-anthropic.ts @@ -0,0 +1,21 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { generateText } from 'ai'; +import dotenv from 'dotenv'; + +dotenv.config(); + +async function main() { + const result = await generateText({ + model: bedrock.anthropicLanguageModel( + 'anthropic.claude-3-haiku-20240307-v1:0', + ), + prompt: 'Invent a new holiday and describe its traditions.', + }); + + console.log(result.text); + console.log(); + console.log('Token usage:', result.usage); + console.log('Finish reason:', result.finishReason); +} + +main().catch(console.error); diff --git a/packages/amazon-bedrock/src/anthropic/anthropic-messages-prompt.ts b/packages/amazon-bedrock/src/anthropic/anthropic-messages-prompt.ts new file mode 100644 index 000000000000..b22f52c3226a --- /dev/null +++ b/packages/amazon-bedrock/src/anthropic/anthropic-messages-prompt.ts @@ -0,0 +1,46 @@ +export type AnthropicMessagesPrompt = { + system?: string; + messages: AnthropicMessage[]; +}; + +export type AnthropicMessage = AnthropicUserMessage | AnthropicAssistantMessage; + +export interface AnthropicUserMessage { + role: 'user'; + content: Array< + AnthropicTextContent | AnthropicImageContent | AnthropicToolResultContent + >; +} + +export interface AnthropicAssistantMessage { + role: 'assistant'; + content: Array; +} + +export interface AnthropicTextContent { + type: 'text'; + text: string; +} + +export interface AnthropicImageContent { + type: 'image'; + source: { + type: 'base64'; + media_type: string; + data: string; + }; +} + 
+export interface AnthropicToolCallContent { + type: 'tool_use'; + id: string; + name: string; + input: unknown; +} + +export interface AnthropicToolResultContent { + type: 'tool_result'; + tool_use_id: string; + content: unknown; + is_error?: boolean; +} diff --git a/packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts b/packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts new file mode 100644 index 000000000000..0024ada73b23 --- /dev/null +++ b/packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts @@ -0,0 +1,510 @@ +import { + LanguageModelV1, + LanguageModelV1CallWarning, + LanguageModelV1FunctionToolCall, + UnsupportedFunctionalityError, +} from '@ai-sdk/provider'; +import { + BedrockRuntimeClient, + InvokeModelCommand, +} from '@aws-sdk/client-bedrock-runtime'; +import { z } from 'zod'; +import { convertToAnthropicMessagesPrompt } from './convert-to-anthropic-messages-prompt'; +import { safeParseJSON } from '@ai-sdk/provider-utils'; +import { mapAnthropicStopReason } from './map-anthropic-stop-reason'; + +type AnthropicMessagesConfig = { + client: BedrockRuntimeClient; +}; + +export class BedrockAnthropicMessagesLanguageModel implements LanguageModelV1 { + readonly provider = 'amazon-bedrock-anthropic-messages'; + readonly specificationVersion = 'v1'; + readonly defaultObjectGenerationMode = 'tool'; + readonly supportsImageUrls = false; + + readonly modelId: string; + + private readonly config: AnthropicMessagesConfig; + + constructor(modelId: string, settings: {}, config: AnthropicMessagesConfig) { + this.modelId = modelId; + this.config = config; + } + + private async getArgs({ + mode, + prompt, + maxTokens, + temperature, + topP, + topK, + frequencyPenalty, + presencePenalty, + stopSequences, + responseFormat, + seed, + }: Parameters[0]) { + const type = mode.type; + + const warnings: LanguageModelV1CallWarning[] = []; + + if (frequencyPenalty != null) { + 
warnings.push({ + type: 'unsupported-setting', + setting: 'frequencyPenalty', + }); + } + + if (presencePenalty != null) { + warnings.push({ + type: 'unsupported-setting', + setting: 'presencePenalty', + }); + } + + if (seed != null) { + warnings.push({ + type: 'unsupported-setting', + setting: 'seed', + }); + } + + if (responseFormat != null && responseFormat.type !== 'text') { + warnings.push({ + type: 'unsupported-setting', + setting: 'responseFormat', + details: 'JSON response format is not supported.', + }); + } + + const messagesPrompt = convertToAnthropicMessagesPrompt(prompt); + + const baseArgs = { + anthropic_version: 'bedrock-2023-05-31', + + // model specific settings: + top_k: topK, + + // standardized settings: + max_tokens: maxTokens ?? 4096, // 4096: max model output tokens + temperature, + top_p: topP, + stop_sequences: stopSequences, + + // prompt: + system: messagesPrompt.system, + messages: messagesPrompt.messages, + }; + + switch (type) { + case 'regular': { + return { + args: { ...baseArgs, ...prepareToolsAndToolChoice(mode) }, + warnings, + }; + } + + case 'object-json': { + throw new UnsupportedFunctionalityError({ + functionality: 'json-mode object generation', + }); + } + + case 'object-tool': { + const { name, description, parameters } = mode.tool; + + return { + args: { + ...baseArgs, + tools: [{ name, description, input_schema: parameters }], + tool_choice: { type: 'tool', name }, + }, + warnings, + }; + } + + default: { + const _exhaustiveCheck: never = type; + throw new Error(`Unsupported type: ${_exhaustiveCheck}`); + } + } + } + + async doGenerate( + options: Parameters[0], + ): Promise>> { + const { args, warnings } = await this.getArgs(options); + + const { body } = await this.config.client.send( + new InvokeModelCommand({ + modelId: this.modelId, + contentType: 'application/json', + accept: 'application/json', + body: JSON.stringify(args), + }), + ); + + const parsedBody = safeParseJSON({ + text: body.transformToString(), + 
schema: anthropicMessagesResponseSchema, + }); + + if (!parsedBody.success) { + // TODO better error handling + throw new Error('Failed to parse response body'); + } + + const response = parsedBody.value; + + const { messages: rawPrompt, ...rawSettings } = args; + + // extract text + let text = ''; + for (const content of response.content) { + if (content.type === 'text') { + text += content.text; + } + } + + // // extract tool calls + let toolCalls: LanguageModelV1FunctionToolCall[] | undefined = undefined; + if (response.content.some(content => content.type === 'tool_use')) { + toolCalls = []; + for (const content of response.content) { + if (content.type === 'tool_use') { + toolCalls.push({ + toolCallType: 'function', + toolCallId: content.id, + toolName: content.name, + args: JSON.stringify(content.input), + }); + } + } + } + + return { + text, + toolCalls, + finishReason: mapAnthropicStopReason(response.stop_reason), + usage: { + promptTokens: response.usage.input_tokens, + completionTokens: response.usage.output_tokens, + }, + rawCall: { rawPrompt, rawSettings }, + warnings, + }; + } + + async doStream( + options: Parameters[0], + ): Promise>> { + throw new Error('Not implemented'); + + // const { args, warnings } = await this.getArgs(options); + + // const { responseHeaders, value: response } = await postJsonToApi({ + // url: `${this.config.baseURL}/messages`, + // headers: combineHeaders(this.config.headers(), options.headers), + // body: { + // ...args, + // stream: true, + // }, + // failedResponseHandler: anthropicFailedResponseHandler, + // successfulResponseHandler: createEventSourceResponseHandler( + // anthropicMessagesChunkSchema, + // ), + // abortSignal: options.abortSignal, + // fetch: this.config.fetch, + // }); + + // const { messages: rawPrompt, ...rawSettings } = args; + + // let finishReason: LanguageModelV1FinishReason = 'other'; + // const usage: { promptTokens: number; completionTokens: number } = { + // promptTokens: Number.NaN, + // 
completionTokens: Number.NaN, + // }; + + // const toolCallContentBlocks: Record< + // number, + // { + // toolCallId: string; + // toolName: string; + // jsonText: string; + // } + // > = {}; + + // return { + // stream: response.pipeThrough( + // new TransformStream< + // ParseResult>, + // LanguageModelV1StreamPart + // >({ + // transform(chunk, controller) { + // if (!chunk.success) { + // controller.enqueue({ type: 'error', error: chunk.error }); + // return; + // } + + // const value = chunk.value; + + // switch (value.type) { + // case 'ping': { + // return; // ignored + // } + + // case 'content_block_start': { + // const contentBlockType = value.content_block.type; + + // switch (contentBlockType) { + // case 'text': { + // return; // ignored + // } + + // case 'tool_use': { + // toolCallContentBlocks[value.index] = { + // toolCallId: value.content_block.id, + // toolName: value.content_block.name, + // jsonText: '', + // }; + // return; + // } + + // default: { + // const _exhaustiveCheck: never = contentBlockType; + // throw new Error( + // `Unsupported content block type: ${_exhaustiveCheck}`, + // ); + // } + // } + // } + + // case 'content_block_stop': { + // // when finishing a tool call block, send the full tool call: + // if (toolCallContentBlocks[value.index] != null) { + // const contentBlock = toolCallContentBlocks[value.index]; + + // controller.enqueue({ + // type: 'tool-call', + // toolCallType: 'function', + // toolCallId: contentBlock.toolCallId, + // toolName: contentBlock.toolName, + // args: contentBlock.jsonText, + // }); + + // delete toolCallContentBlocks[value.index]; + // } + + // return; + // } + + // case 'content_block_delta': { + // const deltaType = value.delta.type; + // switch (deltaType) { + // case 'text_delta': { + // controller.enqueue({ + // type: 'text-delta', + // textDelta: value.delta.text, + // }); + + // return; + // } + + // case 'input_json_delta': { + // const contentBlock = toolCallContentBlocks[value.index]; 
+ + // controller.enqueue({ + // type: 'tool-call-delta', + // toolCallType: 'function', + // toolCallId: contentBlock.toolCallId, + // toolName: contentBlock.toolName, + // argsTextDelta: value.delta.partial_json, + // }); + + // contentBlock.jsonText += value.delta.partial_json; + + // return; + // } + + // default: { + // const _exhaustiveCheck: never = deltaType; + // throw new Error( + // `Unsupported delta type: ${_exhaustiveCheck}`, + // ); + // } + // } + // } + + // case 'message_start': { + // usage.promptTokens = value.message.usage.input_tokens; + // usage.completionTokens = value.message.usage.output_tokens; + // return; + // } + + // case 'message_delta': { + // usage.completionTokens = value.usage.output_tokens; + // finishReason = mapAnthropicStopReason(value.delta.stop_reason); + // return; + // } + + // case 'message_stop': { + // controller.enqueue({ type: 'finish', finishReason, usage }); + // return; + // } + + // case 'error': { + // controller.enqueue({ type: 'error', error: value.error }); + // return; + // } + + // default: { + // const _exhaustiveCheck: never = value; + // throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`); + // } + // } + // }, + // }), + // ), + // rawCall: { rawPrompt, rawSettings }, + // rawResponse: { headers: responseHeaders }, + // warnings, + // }; + } +} + +// limited version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const anthropicMessagesResponseSchema = z.object({ + type: z.literal('message'), + content: z.array( + z.discriminatedUnion('type', [ + z.object({ + type: z.literal('text'), + text: z.string(), + }), + z.object({ + type: z.literal('tool_use'), + id: z.string(), + name: z.string(), + input: z.unknown(), + }), + ]), + ), + stop_reason: z.string().optional().nullable(), + usage: z.object({ + input_tokens: z.number(), + output_tokens: z.number(), + }), +}); + +// limited version of the 
schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const anthropicMessagesChunkSchema = z.discriminatedUnion('type', [ + z.object({ + type: z.literal('message_start'), + message: z.object({ + usage: z.object({ + input_tokens: z.number(), + output_tokens: z.number(), + }), + }), + }), + z.object({ + type: z.literal('content_block_start'), + index: z.number(), + content_block: z.discriminatedUnion('type', [ + z.object({ + type: z.literal('text'), + text: z.string(), + }), + z.object({ + type: z.literal('tool_use'), + id: z.string(), + name: z.string(), + }), + ]), + }), + z.object({ + type: z.literal('content_block_delta'), + index: z.number(), + delta: z.discriminatedUnion('type', [ + z.object({ + type: z.literal('input_json_delta'), + partial_json: z.string(), + }), + z.object({ + type: z.literal('text_delta'), + text: z.string(), + }), + ]), + }), + z.object({ + type: z.literal('content_block_stop'), + index: z.number(), + }), + z.object({ + type: z.literal('error'), + error: z.object({ + type: z.string(), + message: z.string(), + }), + }), + z.object({ + type: z.literal('message_delta'), + delta: z.object({ stop_reason: z.string().optional().nullable() }), + usage: z.object({ output_tokens: z.number() }), + }), + z.object({ + type: z.literal('message_stop'), + }), + z.object({ + type: z.literal('ping'), + }), +]); + +function prepareToolsAndToolChoice( + mode: Parameters[0]['mode'] & { + type: 'regular'; + }, +) { + // when the tools array is empty, change it to undefined to prevent errors: + const tools = mode.tools?.length ? 
/**
 * Converts a standardized AI SDK prompt into the Anthropic Messages format.
 *
 * Consecutive messages are first grouped into blocks (see `groupIntoBlocks`):
 * - a system block is joined with newlines into the single `system` string,
 * - a user block (user + tool messages) becomes one `user` message,
 * - an assistant block becomes one `assistant` message.
 *
 * @param prompt - The standardized prompt to convert.
 * @returns The `system` string (if any) and the list of Anthropic messages.
 * @throws UnsupportedFunctionalityError when a second system block follows
 *   user/assistant messages, or when a user image part is a URL.
 */
export function convertToAnthropicMessagesPrompt(
  prompt: LanguageModelV1Prompt,
): AnthropicMessagesPrompt {
  const blocks = groupIntoBlocks(prompt);

  let system: string | undefined = undefined;
  const messages: AnthropicMessage[] = [];

  for (let i = 0; i < blocks.length; i++) {
    const block = blocks[i];
    const type = block.type;

    switch (type) {
      case 'system': {
        // only one system block is representable; a second one means it was
        // separated from the first by user/assistant messages
        if (system != null) {
          throw new UnsupportedFunctionalityError({
            functionality:
              'Multiple system messages that are separated by user/assistant messages',
          });
        }

        system = block.messages.map(({ content }) => content).join('\n');
        break;
      }

      case 'user': {
        // combines all user and tool messages in this block into a single message:
        const anthropicContent: AnthropicUserMessage['content'] = [];

        for (const { role, content } of block.messages) {
          switch (role) {
            case 'user': {
              for (const part of content) {
                switch (part.type) {
                  case 'text': {
                    anthropicContent.push({ type: 'text', text: part.text });
                    break;
                  }
                  case 'image': {
                    if (part.image instanceof URL) {
                      // The AI SDK automatically downloads images for user image parts with URLs
                      throw new UnsupportedFunctionalityError({
                        functionality: 'Image URLs in user messages',
                      });
                    }

                    anthropicContent.push({
                      type: 'image',
                      source: {
                        type: 'base64',
                        media_type: part.mimeType ?? 'image/jpeg',
                        data: convertUint8ArrayToBase64(part.image),
                      },
                    });

                    break;
                  }
                }
              }

              break;
            }
            case 'tool': {
              // tool results are sent as part of the user turn
              for (const part of content) {
                anthropicContent.push({
                  type: 'tool_result',
                  tool_use_id: part.toolCallId,
                  content: JSON.stringify(part.result),
                  is_error: part.isError,
                });
              }

              break;
            }
            default: {
              const _exhaustiveCheck: never = role;
              throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
            }
          }
        }

        messages.push({ role: 'user', content: anthropicContent });

        break;
      }

      case 'assistant': {
        // combines multiple assistant messages in this block into a single message:
        const anthropicContent: AnthropicAssistantMessage['content'] = [];
        const isLastBlock = i === blocks.length - 1;

        for (let j = 0; j < block.messages.length; j++) {
          const { content } = block.messages[j];
          const isLastMessage = j === block.messages.length - 1;

          for (let k = 0; k < content.length; k++) {
            const part = content[k];
            const isLastPart = k === content.length - 1;

            switch (part.type) {
              case 'text': {
                anthropicContent.push({
                  type: 'text',
                  text:
                    // trim only the very last text part of the prompt (last
                    // part of the last message of the last block), because
                    // Anthropic does not allow trailing whitespace
                    // in pre-filled assistant responses
                    isLastBlock && isLastMessage && isLastPart
                      ? part.text.trim()
                      : part.text,
                });
                break;
              }

              case 'tool-call': {
                anthropicContent.push({
                  type: 'tool_use',
                  id: part.toolCallId,
                  name: part.toolName,
                  input: part.args,
                });
                break;
              }
            }
          }
        }

        messages.push({ role: 'assistant', content: anthropicContent });

        break;
      }

      default: {
        const _exhaustiveCheck: never = type;
        throw new Error(`Unsupported type: ${_exhaustiveCheck}`);
      }
    }
  }

  return {
    system,
    messages,
  };
}

/** A run of consecutive system messages. */
type SystemBlock = {
  type: 'system';
  messages: Array<LanguageModelV1Message & { role: 'system' }>;
};
/** A run of consecutive assistant messages. */
type AssistantBlock = {
  type: 'assistant';
  messages: Array<LanguageModelV1Message & { role: 'assistant' }>;
};
/** A run of consecutive user and/or tool messages. */
type UserBlock = {
  type: 'user';
  messages: Array<LanguageModelV1Message & { role: 'user' | 'tool' }>;
};

/**
 * Groups consecutive prompt messages into blocks by role.
 *
 * A new block is started whenever the block type changes. `tool` messages are
 * placed in `user` blocks, so a user message followed by a tool message (or
 * vice versa) ends up in the same block.
 *
 * @param prompt - The standardized prompt.
 * @returns The blocks in prompt order.
 */
function groupIntoBlocks(
  prompt: LanguageModelV1Prompt,
): Array<SystemBlock | AssistantBlock | UserBlock> {
  const blocks: Array<SystemBlock | AssistantBlock | UserBlock> = [];
  let currentBlock: SystemBlock | AssistantBlock | UserBlock | undefined =
    undefined;

  for (const { role, content } of prompt) {
    switch (role) {
      case 'system': {
        if (currentBlock?.type !== 'system') {
          currentBlock = { type: 'system', messages: [] };
          blocks.push(currentBlock);
        }

        currentBlock.messages.push({ role, content });
        break;
      }
      case 'assistant': {
        if (currentBlock?.type !== 'assistant') {
          currentBlock = { type: 'assistant', messages: [] };
          blocks.push(currentBlock);
        }

        currentBlock.messages.push({ role, content });
        break;
      }
      case 'user': {
        if (currentBlock?.type !== 'user') {
          currentBlock = { type: 'user', messages: [] };
          blocks.push(currentBlock);
        }

        currentBlock.messages.push({ role, content });
        break;
      }
      case 'tool': {
        // tool messages share the user block
        if (currentBlock?.type !== 'user') {
          currentBlock = { type: 'user', messages: [] };
          blocks.push(currentBlock);
        }

        currentBlock.messages.push({ role, content });
        break;
      }
      default: {
        const _exhaustiveCheck: never = role;
        throw new Error(`Unsupported role: ${_exhaustiveCheck}`);
      }
    }
  }

  return blocks;
}
/**
 * Maps an Anthropic `stop_reason` value to the AI SDK finish reason.
 *
 * `end_turn` and `stop_sequence` map to `stop`, `tool_use` to `tool-calls`,
 * and `max_tokens` to `length`. Every other value, including `null` and
 * `undefined`, maps to `other`.
 *
 * @param finishReason - The raw stop reason from the response, if any.
 * @returns The corresponding standardized finish reason.
 */
export function mapAnthropicStopReason(
  finishReason: string | null | undefined,
): LanguageModelV1FinishReason {
  if (finishReason === 'end_turn' || finishReason === 'stop_sequence') {
    return 'stop';
  }

  if (finishReason === 'tool_use') {
    return 'tool-calls';
  }

  if (finishReason === 'max_tokens') {
    return 'length';
  }

  return 'other';
}
createChatModel; + provider.anthropicLanguageModel = createAnthropicLanguageModel; return provider as AmazonBedrockProvider; } From 98648504aec531e3d826969bae1052dfca136971 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Fri, 9 Aug 2024 12:00:47 +0200 Subject: [PATCH 2/3] streaming --- .../stream-text/amazon-bedrock-anthropic.ts | 24 ++ ...drock-anthropic-messages-language-model.ts | 376 +++++++++--------- ...nvert-async-iterable-to-readable-stream.ts | 16 + .../anthropic/convert-uint-array-to-text.ts | 11 + 4 files changed, 247 insertions(+), 180 deletions(-) create mode 100644 examples/ai-core/src/stream-text/amazon-bedrock-anthropic.ts create mode 100644 packages/amazon-bedrock/src/anthropic/convert-async-iterable-to-readable-stream.ts create mode 100644 packages/amazon-bedrock/src/anthropic/convert-uint-array-to-text.ts diff --git a/examples/ai-core/src/stream-text/amazon-bedrock-anthropic.ts b/examples/ai-core/src/stream-text/amazon-bedrock-anthropic.ts new file mode 100644 index 000000000000..0ea7f38fad0f --- /dev/null +++ b/examples/ai-core/src/stream-text/amazon-bedrock-anthropic.ts @@ -0,0 +1,24 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { streamText } from 'ai'; +import dotenv from 'dotenv'; + +dotenv.config(); + +async function main() { + const result = await streamText({ + model: bedrock.anthropicLanguageModel( + 'anthropic.claude-3-haiku-20240307-v1:0', + ), + prompt: 'Invent a new holiday and describe its traditions.', + }); + + for await (const textPart of result.textStream) { + process.stdout.write(textPart); + } + + console.log(); + console.log('Token usage:', await result.usage); + console.log('Finish reason:', await result.finishReason); +} + +main().catch(console.error); diff --git a/packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts b/packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts index 0024ada73b23..2f93de633029 100644 --- 
a/packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts +++ b/packages/amazon-bedrock/src/anthropic/bedrock-anthropic-messages-language-model.ts @@ -1,16 +1,23 @@ import { + EmptyResponseBodyError, LanguageModelV1, LanguageModelV1CallWarning, + LanguageModelV1FinishReason, LanguageModelV1FunctionToolCall, + LanguageModelV1StreamPart, UnsupportedFunctionalityError, } from '@ai-sdk/provider'; +import { parseJSON, ParseResult, safeParseJSON } from '@ai-sdk/provider-utils'; import { BedrockRuntimeClient, InvokeModelCommand, + InvokeModelWithResponseStreamCommand, + ResponseStream, } from '@aws-sdk/client-bedrock-runtime'; import { z } from 'zod'; +import { convertAsyncIterableToReadableStream } from './convert-async-iterable-to-readable-stream'; import { convertToAnthropicMessagesPrompt } from './convert-to-anthropic-messages-prompt'; -import { safeParseJSON } from '@ai-sdk/provider-utils'; +import { convertUint8ArrayToText } from './convert-uint-array-to-text'; import { mapAnthropicStopReason } from './map-anthropic-stop-reason'; type AnthropicMessagesConfig = { @@ -145,18 +152,11 @@ export class BedrockAnthropicMessagesLanguageModel implements LanguageModelV1 { }), ); - const parsedBody = safeParseJSON({ + const response = parseJSON({ text: body.transformToString(), schema: anthropicMessagesResponseSchema, }); - if (!parsedBody.success) { - // TODO better error handling - throw new Error('Failed to parse response body'); - } - - const response = parsedBody.value; - const { messages: rawPrompt, ...rawSettings } = args; // extract text @@ -199,177 +199,193 @@ export class BedrockAnthropicMessagesLanguageModel implements LanguageModelV1 { async doStream( options: Parameters[0], ): Promise>> { - throw new Error('Not implemented'); - - // const { args, warnings } = await this.getArgs(options); - - // const { responseHeaders, value: response } = await postJsonToApi({ - // url: `${this.config.baseURL}/messages`, - // headers: 
combineHeaders(this.config.headers(), options.headers), - // body: { - // ...args, - // stream: true, - // }, - // failedResponseHandler: anthropicFailedResponseHandler, - // successfulResponseHandler: createEventSourceResponseHandler( - // anthropicMessagesChunkSchema, - // ), - // abortSignal: options.abortSignal, - // fetch: this.config.fetch, - // }); - - // const { messages: rawPrompt, ...rawSettings } = args; - - // let finishReason: LanguageModelV1FinishReason = 'other'; - // const usage: { promptTokens: number; completionTokens: number } = { - // promptTokens: Number.NaN, - // completionTokens: Number.NaN, - // }; - - // const toolCallContentBlocks: Record< - // number, - // { - // toolCallId: string; - // toolName: string; - // jsonText: string; - // } - // > = {}; - - // return { - // stream: response.pipeThrough( - // new TransformStream< - // ParseResult>, - // LanguageModelV1StreamPart - // >({ - // transform(chunk, controller) { - // if (!chunk.success) { - // controller.enqueue({ type: 'error', error: chunk.error }); - // return; - // } - - // const value = chunk.value; - - // switch (value.type) { - // case 'ping': { - // return; // ignored - // } - - // case 'content_block_start': { - // const contentBlockType = value.content_block.type; - - // switch (contentBlockType) { - // case 'text': { - // return; // ignored - // } - - // case 'tool_use': { - // toolCallContentBlocks[value.index] = { - // toolCallId: value.content_block.id, - // toolName: value.content_block.name, - // jsonText: '', - // }; - // return; - // } - - // default: { - // const _exhaustiveCheck: never = contentBlockType; - // throw new Error( - // `Unsupported content block type: ${_exhaustiveCheck}`, - // ); - // } - // } - // } - - // case 'content_block_stop': { - // // when finishing a tool call block, send the full tool call: - // if (toolCallContentBlocks[value.index] != null) { - // const contentBlock = toolCallContentBlocks[value.index]; - - // controller.enqueue({ - // 
type: 'tool-call', - // toolCallType: 'function', - // toolCallId: contentBlock.toolCallId, - // toolName: contentBlock.toolName, - // args: contentBlock.jsonText, - // }); - - // delete toolCallContentBlocks[value.index]; - // } - - // return; - // } - - // case 'content_block_delta': { - // const deltaType = value.delta.type; - // switch (deltaType) { - // case 'text_delta': { - // controller.enqueue({ - // type: 'text-delta', - // textDelta: value.delta.text, - // }); - - // return; - // } - - // case 'input_json_delta': { - // const contentBlock = toolCallContentBlocks[value.index]; - - // controller.enqueue({ - // type: 'tool-call-delta', - // toolCallType: 'function', - // toolCallId: contentBlock.toolCallId, - // toolName: contentBlock.toolName, - // argsTextDelta: value.delta.partial_json, - // }); - - // contentBlock.jsonText += value.delta.partial_json; - - // return; - // } - - // default: { - // const _exhaustiveCheck: never = deltaType; - // throw new Error( - // `Unsupported delta type: ${_exhaustiveCheck}`, - // ); - // } - // } - // } - - // case 'message_start': { - // usage.promptTokens = value.message.usage.input_tokens; - // usage.completionTokens = value.message.usage.output_tokens; - // return; - // } - - // case 'message_delta': { - // usage.completionTokens = value.usage.output_tokens; - // finishReason = mapAnthropicStopReason(value.delta.stop_reason); - // return; - // } - - // case 'message_stop': { - // controller.enqueue({ type: 'finish', finishReason, usage }); - // return; - // } - - // case 'error': { - // controller.enqueue({ type: 'error', error: value.error }); - // return; - // } - - // default: { - // const _exhaustiveCheck: never = value; - // throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`); - // } - // } - // }, - // }), - // ), - // rawCall: { rawPrompt, rawSettings }, - // rawResponse: { headers: responseHeaders }, - // warnings, - // }; + const { args, warnings } = await this.getArgs(options); + + const { 
body } = await this.config.client.send( + new InvokeModelWithResponseStreamCommand({ + modelId: this.modelId, + contentType: 'application/json', + accept: 'application/json', + body: JSON.stringify(args), + }), + ); + + if (body == null) { + throw new EmptyResponseBodyError(); + } + + const stream = convertAsyncIterableToReadableStream(body); + + const response = stream.pipeThrough( + new TransformStream< + ResponseStream, + ParseResult> + >({ + transform(chunk, controller) { + const bytes = chunk.chunk?.bytes; + if (bytes != null) { + controller.enqueue( + safeParseJSON({ + text: convertUint8ArrayToText(bytes), + schema: anthropicMessagesChunkSchema, + }), + ); + } + }, + }), + ); + + const { messages: rawPrompt, ...rawSettings } = args; + + let finishReason: LanguageModelV1FinishReason = 'other'; + const usage: { promptTokens: number; completionTokens: number } = { + promptTokens: Number.NaN, + completionTokens: Number.NaN, + }; + + const toolCallContentBlocks: Record< + number, + { + toolCallId: string; + toolName: string; + jsonText: string; + } + > = {}; + + return { + stream: response.pipeThrough( + new TransformStream< + ParseResult>, + LanguageModelV1StreamPart + >({ + transform(chunk, controller) { + if (!chunk.success) { + controller.enqueue({ type: 'error', error: chunk.error }); + return; + } + + const value = chunk.value; + + switch (value.type) { + case 'ping': { + return; // ignored + } + + case 'content_block_start': { + const contentBlockType = value.content_block.type; + + switch (contentBlockType) { + case 'text': { + return; // ignored + } + + case 'tool_use': { + toolCallContentBlocks[value.index] = { + toolCallId: value.content_block.id, + toolName: value.content_block.name, + jsonText: '', + }; + return; + } + + default: { + const _exhaustiveCheck: never = contentBlockType; + throw new Error( + `Unsupported content block type: ${_exhaustiveCheck}`, + ); + } + } + } + + case 'content_block_stop': { + // when finishing a tool call block, send 
the full tool call: + if (toolCallContentBlocks[value.index] != null) { + const contentBlock = toolCallContentBlocks[value.index]; + + controller.enqueue({ + type: 'tool-call', + toolCallType: 'function', + toolCallId: contentBlock.toolCallId, + toolName: contentBlock.toolName, + args: contentBlock.jsonText, + }); + + delete toolCallContentBlocks[value.index]; + } + + return; + } + + case 'content_block_delta': { + const deltaType = value.delta.type; + switch (deltaType) { + case 'text_delta': { + controller.enqueue({ + type: 'text-delta', + textDelta: value.delta.text, + }); + + return; + } + + case 'input_json_delta': { + const contentBlock = toolCallContentBlocks[value.index]; + + controller.enqueue({ + type: 'tool-call-delta', + toolCallType: 'function', + toolCallId: contentBlock.toolCallId, + toolName: contentBlock.toolName, + argsTextDelta: value.delta.partial_json, + }); + + contentBlock.jsonText += value.delta.partial_json; + + return; + } + + default: { + const _exhaustiveCheck: never = deltaType; + throw new Error( + `Unsupported delta type: ${_exhaustiveCheck}`, + ); + } + } + } + + case 'message_start': { + usage.promptTokens = value.message.usage.input_tokens; + usage.completionTokens = value.message.usage.output_tokens; + return; + } + + case 'message_delta': { + usage.completionTokens = value.usage.output_tokens; + finishReason = mapAnthropicStopReason(value.delta.stop_reason); + return; + } + + case 'message_stop': { + controller.enqueue({ type: 'finish', finishReason, usage }); + return; + } + + case 'error': { + controller.enqueue({ type: 'error', error: value.error }); + return; + } + + default: { + const _exhaustiveCheck: never = value; + throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`); + } + } + }, + }), + ), + rawCall: { rawPrompt, rawSettings }, + warnings, + }; } } diff --git a/packages/amazon-bedrock/src/anthropic/convert-async-iterable-to-readable-stream.ts 
// ---------------------------------------------------------------------------
// packages/amazon-bedrock/src/anthropic/convert-async-iterable-to-readable-stream.ts
// (new file, mode 100644, index 000000000000..f4f85c009e6a)
// ---------------------------------------------------------------------------

/**
 * Adapts an async iterable to a pull-based web `ReadableStream`.
 *
 * Each `pull` requests exactly one item from the underlying iterator, so the
 * source is only advanced when the consumer asks for more data. Cancelling
 * the stream forwards the reason to the iterator's optional `return()` so the
 * producer gets a chance to clean up.
 *
 * @param iterable - Source of the values to expose as a stream.
 * @returns A stream that yields the iterable's values in order and closes
 * once the iterable reports `done`.
 */
export function convertAsyncIterableToReadableStream<T>(
  iterable: AsyncIterable<T>,
): ReadableStream<T> {
  const iterator = iterable[Symbol.asyncIterator]();

  return new ReadableStream<T>({
    async pull(controller) {
      const { done, value } = await iterator.next();
      if (done) controller.close();
      else controller.enqueue(value);
    },

    async cancel(reason) {
      // `return` is optional on the async iterator protocol.
      await iterator.return?.(reason);
    },
  });
}

// ---------------------------------------------------------------------------
// packages/amazon-bedrock/src/anthropic/convert-uint-array-to-text.ts
// (new file, mode 100644, index 000000000000..00a59fa86e0a)
// ---------------------------------------------------------------------------

/**
 * Decodes a byte array as UTF-8 text.
 *
 * The previous implementation turned every byte into its own code point
 * (i.e. decoded as Latin-1), which corrupts any character that UTF-8 encodes
 * with more than one byte ("é", "€", emoji, ...). The caller in this patch
 * passes the result to a JSON parser, and JSON payloads are UTF-8, so the
 * bytes have to be decoded as UTF-8 to round-trip non-ASCII model output.
 * Pure-ASCII input decodes exactly as before.
 *
 * @param array - Raw bytes of one response chunk.
 * @returns The decoded string; an empty array yields `''`. Invalid byte
 * sequences are replaced with U+FFFD rather than throwing.
 */
export function convertUint8ArrayToText(array: Uint8Array): string {
  return new TextDecoder('utf-8').decode(array);
}

/*
 * The original (flattened) line then continues with the header of the next
 * patch in the series and the first line of the file it adds. Kept verbatim:
 *
 *   From ad055a86e43cf1925e1c75e1d6a36a3303db5d8d Mon Sep 17 00:00:00 2001
 *   From: Lars Grammel
 *   Date: Fri, 9 Aug 2024 12:01:56 +0200
 *   Subject: [PATCH 3/3] streamobject
 *   ---
 *   .../stream-object/amazon-bedrock-anthropic.ts | 34 +++++++++++++++++++
 *   1 file changed, 34 insertions(+)
 *   create mode 100644 examples/ai-core/src/stream-object/amazon-bedrock-anthropic.ts
 *
 *   diff --git a/examples/ai-core/src/stream-object/amazon-bedrock-anthropic.ts b/examples/ai-core/src/stream-object/amazon-bedrock-anthropic.ts
 *   new file mode 100644
 *   index 000000000000..2a1eda7dfd2e
 *   --- /dev/null
 *   +++ b/examples/ai-core/src/stream-object/amazon-bedrock-anthropic.ts
 *   @@ -0,0 +1,34 @@
 *   +import { bedrock } from '@ai-sdk/amazon-bedrock';
 */
import { bedrock } from '@ai-sdk/amazon-bedrock';
import { streamObject } from 'ai';
import dotenv from 'dotenv';
import { z } from 'zod';

dotenv.config();

/** Schema for a single generated role-playing character. */
const characterSchema = z.object({
  name: z.string(),
  class: z
    .string()
    .describe('Character class, e.g. warrior, mage, or thief.'),
  description: z.string(),
});

/**
 * Requests a structured object from an Anthropic model through the Bedrock
 * provider and, for every partial object that arrives on the stream, clears
 * the console and prints the latest snapshot.
 */
async function main() {
  const model = bedrock.anthropicLanguageModel(
    'anthropic.claude-3-5-sonnet-20240620-v1:0',
  );

  const result = await streamObject({
    model,
    schema: z.object({
      characters: z.array(characterSchema),
    }),
    prompt:
      'Generate 3 character descriptions for a fantasy role playing game.',
  });

  for await (const snapshot of result.partialObjectStream) {
    console.clear();
    console.log(snapshot);
  }
}

main().catch(console.error);