diff --git a/.github/workflows/deno.yml b/.github/workflows/deno.yml
index 6051e41..68b3fbd 100644
--- a/.github/workflows/deno.yml
+++ b/.github/workflows/deno.yml
@@ -38,10 +38,10 @@ jobs:
         uses: actions/checkout@v4

       - name: Setup Deno
-        uses: denoland/setup-deno@v1
+        uses: denoland/setup-deno@v2
        # uses: denoland/setup-deno@61fe2df320078202e33d7d5ad347e7dcfa0e8f31 # v1.1.2
         with:
-          deno-version: v1.x
+          deno-version: v2.x.x

       # Uncomment this step to verify the use of 'deno fmt' on each commit.
       - name: Verify formatting
@@ -63,25 +63,35 @@ jobs:
     strategy:
       matrix:
         include:
-          - name: linux
+          - name: linux-x64
             vm: ubuntu-latest
-          - name: macosx
+            target: x86_64-unknown-linux-gnu
+          - name: linux-arm64
+            vm: ubuntu-latest
+            target: aarch64-unknown-linux-gnu
+          - name: macosx-x64
+            vm: macos-latest
+            target: x86_64-apple-darwin
+          - name: macosx-arm64
             vm: macos-latest
-          - name: windows
+            target: aarch64-apple-darwin
+          - name: windows-x64
             vm: windows-latest
+            target: x86_64-pc-windows-msvc
+
     steps:
       - name: Setup repo
         uses: actions/checkout@v4

       - name: Setup Deno
-        uses: denoland/setup-deno@v1
+        uses: denoland/setup-deno@v2
        # uses: denoland/setup-deno@61fe2df320078202e33d7d5ad347e7dcfa0e8f31 # v1.1.2
         with:
-          deno-version: v1.x
+          deno-version: v2.x.x

       - name: Build binary
-        run: deno compile --allow-net --allow-env --output gpt index.ts
+        run: deno compile --allow-net --allow-env --target ${{ matrix.target }} --output gpt index.ts

       - name: Binary upload (Unix, MacOS)
         # Artifact upload only occurs when tag matches
diff --git a/README.md b/README.md
index 4158188..0c12bbf 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-
+
@@ -10,10 +10,16 @@ Command-line interface that enables interactive chat with LLMs.
 # Quick start

 ```
-$ curl -LO https://github.com/u1and0/gpt-cli/releases/download/v0.7.0/gpt-cli-linux.zip
+# Install binary
+$ curl -LO https://github.com/u1and0/gpt-cli/releases/download/v0.8.0/gpt-cli-linux.zip
 $ unzip gpt-cli-linux.zip
 $ chmod 755 gpt
 $ sudo ln -s "$PWD/gpt" /usr/bin
+
+# Setup API key
+$ export OPENAI_API_KEY='sk-*****'
+
+# Run
 $ gpt
 Ctrl-D to confirm input, q or exit to end chat
 You: hi
@@ -46,7 +52,7 @@ You have 3 options.
 The simplest way.

 ```
-$ curl -LO https://github.com/u1and0/gpt-cli/releases/download/v0.7.0/gpt-cli-linux.zip
+$ curl -LO https://github.com/u1and0/gpt-cli/releases/download/v0.8.0/gpt-cli-linux.zip
 $ unzip gpt-cli-linux.zip
 $ chmod 755 gpt
 $ sudo ln -s "$PWD/gpt" /usr/bin
@@ -104,7 +110,7 @@ export OPENAI_API_KEY='sk-*****'
 export ANTHROPIC_API_KEY='sk-ant-*****'
 ```

-### Goolgle API (for Gemini)
+### Google API (for Gemini)

 [Get Google API key](https://aistudio.google.com/app/apikey), then set environment argument.

@@ -112,6 +118,22 @@ export ANTHROPIC_API_KEY='sk-ant-*****'
 export GOOGLE_API_KEY='*****'
 ```

+### Groq API
+
+[Get Groq API key](https://console.groq.com/keys), then set environment argument.
+
+```
+export GROQ_API_KEY='*****'
+```
+
+### Together AI API
+
+[Get Together AI API key](https://api.together.xyz/settings/api-keys), then set environment argument.
+
+```
+export TOGETHER_AI_API_KEY='*****'
+```
+
 ### Replicate API (for Open Models)

 [Get Replicate API token](https://replicate.com/account/api-tokens), then set environment argument.
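For a quick check that a key set this way is actually picked up, here is a minimal Deno sketch (illustrative only, not part of this patch; the model name is just an example):

```typescript
// Illustrative: @langchain/groq falls back to the GROQ_API_KEY env var
// when apiKey is not passed explicitly.
import { ChatGroq } from "npm:@langchain/groq";

const llm = new ChatGroq({
  apiKey: Deno.env.get("GROQ_API_KEY"), // explicit here only for clarity
  model: "llama-3.3-70b-versatile", // example model name
  temperature: 0,
});
const answer = await llm.invoke("Say hi in one word.");
console.log(answer.content);
```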
@@ -138,7 +160,7 @@ $ gpt -m gpt-4o-mini -x 1000 -t 1.0 [OPTIONS] PROMPT
 |--------------|-------------|------|----|
 | -v | --version | boolean | Show version |
 | -h | --help | boolean | Show this message |
-| -m | --model | string | OpenAI, Anthropic, Google, Replicate, Ollama model (default gpt-4o-mini) |
+| -m | --model | string | LLM model (default gpt-4o-mini) |
 | -x | --max\_tokens | number | Number of AI answer tokens (default 1000) |
 | -t | --temperature | number | Higher number means more creative answers, lower number means more exact answers (default 1.0) |
 | -u | --url | string | URL and port number for ollama server |
@@ -153,38 +175,54 @@ A Questions for Model

 ## Models
 - [OpenAI](https://platform.openai.com/docs/models)
   - gpt-4o-mini
-  - gpt-4o...
+  - gpt-4o
+  - o1
+  - o1-preview
+  - o1-mini...
 - [Anthropic](https://docs.anthropic.com/claude/docs/models-overview)
   - claude-3-opus-20240229
-  - claude-3-haiku-20240307
-  - claude-instant-1.2...
+  - claude-3-5-sonnet-latest
+  - claude-3-5-haiku-latest
 - [Gemini](https://ai.google.dev/gemini-api/docs/models/gemini)
   - gemini-1.5-pro-latest
-  - gemini-pro
+  - gemini-2.0-flash-exp...
+- [Groq](https://console.groq.com/docs/models)
+  - groq/llama3-groq-70b-8192-tool-use-preview
+  - groq/llama-3.3-70b-specdec
+  - groq/llama3.1-70b-specdec
+  - groq/llama-3.2-1b-preview
+  - groq/llama-3.2-3b-preview...
+- [TogetherAI](https://api.together.ai/models)
+  - togetherai/meta-llama/Llama-3.3-70B-Instruct-Turbo
+  - togetherai/Qwen/QwQ-32B-Preview
+  - togetherai/meta-llama/Llama-3.1-405B-Instruct-Turbo
+  - togetherai/google/gemma-2-27b-it
+  - togetherai/mistralai/Mistral-7B-Instruct-v0.3...
 - [Replicate](https://replicate.com/models)
-  - meta/meta-llama-3-70b-instruct
-  - meta/llama-2-7b-chat
-  - mistralai/mistral-7b-instruct-v0.2
-  - mistralai/mixtral-8x7b-instruct-v0.1
-  - snowflake/snowflake-arctic-instruct
-  - replicate/flan-t5-xl...
+  - replicate/meta/meta-llama-3-70b-instruct
+  - replicate/meta/llama-2-7b-chat
+  - replicate/mistralai/mistral-7b-instruct-v0.2
+  - replicate/mistralai/mixtral-8x7b-instruct-v0.1
+  - replicate/snowflake/snowflake-arctic-instruct
+  - replicate/replicate/flan-t5-xl...
 - [Ollama](https://ollama.com/library) ** Run "$ ollama serve" locally beforehand **
-  - phi3
-  - llama3:70b
-  - mixtral:8x7b-text-v0.1-q5\_K\_M...
+  - ollama/phi3
+  - ollama/llama3:70b
+  - ollama/mixtral:8x7b-text-v0.1-q5\_K\_M...

 ## / command
 Help (/commands):
-- /?, /help Help for a command
-- /clear Clear session context
-- /bye Exit
+- /?, /help Help for a command
+- /clear Clear session context
+- /modelStack Show model history
+- /bye,/exit,/quit Exit

 ## @ command
 Help (@commands): Change model while asking.
-- @{ModelName} Change LLM model
-- ex) @gemini-1.5-pro any prompt...
+- @ModelName Change LLM model
+- ex) @gemini-1.5-pro your question...
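What the @ command parser returns, sketched against the extractAtModel helper from lib/command.ts (illustrative input and outputs):

```typescript
import { extractAtModel } from "./lib/command.ts";

const { model, message } = extractAtModel("@gemini-1.5-pro explain this error");
console.log(model); // "gemini-1.5-pro"
console.log(message?.content); // the rest of the prompt, wrapped in a HumanMessage;
// message is undefined when the input is only "@gemini-1.5-pro",
// in which case index.ts reuses the previous prompt (see its hunk below).
```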
 ## Test

@@ -215,12 +253,12 @@ hook_add = '''
     " Create test code
     command! -nargs=0 -range GPTGenerateTest <line1>,<line2>call gptcli#GPT('You are the best code tester. Please write test code that covers all cases to try the given code.', { "temperature": 0.5, "model": "claude-3-haiku-20240307" })
     command! -nargs=0 -range GPTErrorBustor <line1>,<line2>call gptcli#GPT('Your task is to analyze the provided code snippet, identify any bugs or errors present, and provide a corrected version of the code that resolves these issues. Explain the problems you found in the original code and how your fixes address them. The corrected code should be functional, efficient, and adhere to best practices in programming.', {"temperature": 0.5, "model": "claude-3-sonnet-20240229"})
-    command! -nargs=0 -range GPTCodeOptimizer <line1>,<line2>call gptcli#GPT("Your task is to analyze the provided code snippet and suggest improvements to optimize its performance. Identify areas where the code can be made more efficient, faster, or less resource-intensive. Provide specific suggestions for optimization, along with explanations of how these changes can enhance the code performance. The optimized code should maintain the same functionality as the original code while demonstrating improved efficiency.", { "model": "meta/meta-llama-3-70b-instruct" })
+    command! -nargs=0 -range GPTCodeOptimizer <line1>,<line2>call gptcli#GPT("Your task is to analyze the provided code snippet and suggest improvements to optimize its performance. Identify areas where the code can be made more efficient, faster, or less resource-intensive. Provide specific suggestions for optimization, along with explanations of how these changes can enhance the code performance. The optimized code should maintain the same functionality as the original code while demonstrating improved efficiency.", { "model": "replicate/meta/meta-llama-3-70b-instruct" })

     " Any system prompt
     command! -nargs=? -range GPTComplete <line1>,<line2>call gptcli#GPT(<q-args>, { "model": "claude-3-haiku-20240307" })
     " Chat with GPT
-    command! -nargs=? GPTChat call gptcli#GPTWindow(<q-args>, { "model": "phi3:instruct", "url": "http://localhost:11434"})
+    command! -nargs=? GPTChat call gptcli#GPTWindow(<q-args>, { "model": "ollama/phi3:instruct", "url": "http://localhost:11434"})
 ```

 ![Peek 2024-04-01 03-35.gif](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/113494/f243f19b-ee47-9821-5899-7ed2acc17320.gif)
diff --git a/deno.json b/deno.json
index 22249e3..6717cd9 100644
--- a/deno.json
+++ b/deno.json
@@ -4,5 +4,10 @@
   },
   "test": {
     "exclude": ["node_modules/"]
-  }
+  },
+  "dependencies": {
+    "@langchain/groq": "^0.1.2"
+  },
+  "nodeModulesDir": "auto"
 }
+
diff --git a/index.ts b/index.ts
index 1cfb097..c693ca2 100644
--- a/index.ts
+++ b/index.ts
@@ -25,9 +25,21 @@ import { commandMessage, helpMessage } from "./lib/help.ts";
 import { LLM, Message } from "./lib/llm.ts";
 import { getUserInputInMessage, readStdin } from "./lib/input.ts";
 import { Params, parseArgs } from "./lib/params.ts";
-import { Command, extractAtModel, isCommand } from "./lib/command.ts";
-
-const VERSION = "v0.7.0";
+import {
+  Command,
+  extractAtModel,
+  handleSlashCommand,
+  isAtCommand,
+  isSlashCommand,
+  modelStack,
+} from "./lib/command.ts";
+
+const VERSION = "v0.8.0";
+
+/** Print text in gray */
+function consoleInfoWithGrayText(s: string): void {
+  console.info(`\x1b[90m${s}\x1b[0m`);
+}

 const llmAsk = async (params: Params) => {
   params.debug && console.debug(params);
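The gray prompt text above is plain ANSI styling; for reference (illustrative):

```typescript
// \x1b[90m switches the foreground to bright black (gray), \x1b[0m resets it.
console.info("\x1b[90mThis line renders dim gray in ANSI-capable terminals\x1b[0m");
```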
@@ -48,51 +60,46 @@ const llmAsk = async (params: Params) => {
   }

   // Interactive conversation
+  consoleInfoWithGrayText(commandMessage);
   while (true) {
     // Wait for user input
-    let humanMessage = await getUserInputInMessage(messages);
-
-    /** Handle the /command
-     * Help: print the help message
-     * Clear: delete the context, keeping the system prompt
-     * Bye: exit the program
-     */
-    if (isCommand(humanMessage)) {
-      switch (humanMessage) {
-        case Command.Help: {
-          console.log(commandMessage);
-          continue; // After handling a slash command, go to the next loop
-        }
-        case Command.Clear: {
-          // If a system prompt is set, keep it and clear the context
-          console.log("Context clear successful");
-          messages = params.systemPrompt
-            ? [new SystemMessage(params.systemPrompt)]
-            : [];
-          continue; // After handling a slash command, go to the next loop
-        }
-        case Command.Bye: {
-          Deno.exit(0);
-        }
-      }
-    } else if (humanMessage?.content.toString().startsWith("@")) {
+    let humanMessage: HumanMessage | Command = await getUserInputInMessage(
+      messages,
+    );
+
+    // Handle the /command
+    if (isSlashCommand(humanMessage)) {
+      messages = handleSlashCommand(humanMessage, messages);
+      continue;
+    } else if (isAtCommand(humanMessage)) {
       // Input starting with @ModelName re-specifies the LLM model
       const { model, message } = extractAtModel(
         humanMessage.content.toString(),
       );
-      // If there is no prompt besides the model name, carry over the previous prompt.
-      humanMessage = message ? message : messages.at(-2);
+      // If there is no prompt besides the model name, carry over the previous prompt.
+      // If there is no previous prompt either, pass an empty HumanMessage.
+      humanMessage = message || messages.at(-2) || new HumanMessage("");
+
+      // If the model given to the @ command parses successfully,
+      // push it onto the model stack and start chatting with the new model.
+      // If parsing fails, report the error, restore the previous model,
+      // and continue the chat with it.
       if (model) {
+        const modelBackup = params.model;
         params.model = model;
-        llm = new LLM(params);
+        try {
+          llm = new LLM(params);
+        } catch (error: unknown) {
+          console.error(error);
+          params.model = modelBackup;
+          continue;
+        }
+        modelStack.push(model);
       }
     }
-    // If the last message is not a HumanMessage
     // Add the user's question
-    if (humanMessage) {
-      messages.push(humanMessage);
-    }
+    messages.push(humanMessage);
     // console.debug(messages);
     // Add the AI's answer
     const aiMessage = await llm.ask(messages);
@@ -119,6 +126,8 @@ const main = async () => {
     Deno.exit(0);
   }

+  // Push the first model used onto modelStack
+  modelStack.push(params.model);
   // Check stdin
   const stdinContent: string | null = await readStdin();
   if (stdinContent) {
diff --git a/lib/command.ts b/lib/command.ts
index 5521112..708152e 100644
--- a/lib/command.ts
+++ b/lib/command.ts
@@ -1,9 +1,15 @@
-import { HumanMessage } from "npm:@langchain/core/messages";
+import { HumanMessage, SystemMessage } from "npm:@langchain/core/messages";
+import { commandMessage } from "./help.ts";
+import { Message } from "./llm.ts";
+
+/** History of the LLM models used in this conversation */
+export const modelStack: string[] = [];

 export type _Command =
   | "/help"
   | "/?"
   | "/clear"
+  | "/modelStack"
   | "/bye"
   | "/exit"
   | "/quit";
@@ -11,27 +17,32 @@ export type _Command =
 export enum Command {
   Help = "HELP",
   Clear = "CLEAR",
+  ModelStack = "MODELSTACK",
   Bye = "BYE",
 }

 // Type guard for the Command type
-export const isCommand = (value: unknown): value is Command => {
+export const isSlashCommand = (value: unknown): value is Command => {
   return Object.values(Command).includes(value as Command);
 };

 // Return the Command value that matches the input
-export const newSlashCommand = (input: string): Command | undefined => {
+export const newSlashCommand = (input: string): Command => {
   const input0 = input.trim().split(/[\s\n\t]+/, 1)[0];
   const commandMap: Record<_Command, Command> = {
     "/help": Command.Help,
     "/?": Command.Help,
     "/clear": Command.Clear,
+    "/modelStack": Command.ModelStack,
     "/bye": Command.Bye,
     "/exit": Command.Bye,
     "/quit": Command.Bye,
   };
-
-  return commandMap[input0 as _Command];
+  const command = commandMap[input0 as _Command];
+  if (!command) {
+    throw new Error(`Invalid command. ${input0}`);
+  }
+  return command;
 };

 type ModelMessage = { model?: string; message?: HumanMessage };
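newSlashCommand now throws on unknown input instead of returning undefined, so callers have to handle the error; the new contract in miniature (illustrative):

```typescript
import { Command, newSlashCommand } from "./lib/command.ts";

console.log(newSlashCommand("/clear rest is ignored") === Command.Clear); // true
try {
  newSlashCommand("/nonsense");
} catch (e) {
  console.error(e); // Error: Invalid command. /nonsense
}
```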
@@ -48,3 +59,46 @@ export const extractAtModel = (input: string): ModelMessage => {
   const message = input1 ? new HumanMessage(input1) : undefined;
   return { model, message };
 };
+
+export function handleSlashCommand(
+  command: Command,
+  messages: Message[],
+): Message[] {
+  switch (command) {
+    case Command.Help: {
+      console.log(commandMessage);
+      break; // After handling a slash command, go to the next loop
+    }
+    case Command.Clear: {
+      console.log("Context clear successful");
+      // Discard everything except the SystemMessage and return a new array
+      return messages.filter((message: Message) =>
+        message instanceof SystemMessage
+      );
+    }
+    // Show the history of models used
+    case Command.ModelStack: {
+      console.log(`You have chatted with:\n${modelStack.join("\n")}`);
+      break;
+    }
+    case Command.Bye: {
+      Deno.exit(0);
+    }
+  }
+  // Return messages unchanged
+  return messages;
+}
+
+/** Check whether the input starts with @ */
+export const isAtCommand = (humanMessage: unknown): boolean => {
+  if (!(humanMessage instanceof HumanMessage)) {
+    return false;
+  }
+  const content = humanMessage.content.toString();
+  if (!content) {
+    return false;
+  }
+  return content.startsWith("@");
+};
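A small usage sketch of handleSlashCommand's /clear branch (illustrative):

```typescript
import { HumanMessage, SystemMessage } from "npm:@langchain/core/messages";
import { Command, handleSlashCommand } from "./lib/command.ts";

const history = [
  new SystemMessage("You are a terse assistant."),
  new HumanMessage("hi"),
];
const cleared = handleSlashCommand(Command.Clear, history);
console.log(cleared.length); // 1: only the system prompt survives /clear
```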
diff --git a/lib/command_test.ts b/lib/command_test.ts
deleted file mode 100644
index 848672c..0000000
--- a/lib/command_test.ts
+++ /dev/null
@@ -1,30 +0,0 @@
-import { assertEquals } from "https://deno.land/std@0.224.0/assert/assert_equals.ts";
-import { Command, extractAtModel, newSlashCommand } from "../lib/command.ts";
-
-Deno.test("SlashCommand constructor", () => {
-  const testCases: [string, Command | string][] = [
-    ["/help", "HELP"],
-    ["/help my content", "HELP"], // multi word test
-    ["/help\nmy content", "HELP"], // \n trim test
-    [" /help my content", "HELP"], // trim() test
-    ["/help\tmy content", "HELP"], // \t trim test
-  ];
-
-  for (const [args, expected] of testCases) {
-    const actual = newSlashCommand(args);
-    assertEquals(actual, expected);
-  }
-});
-
-Deno.test("When the input starts with @, return the model name that follows it", () => {
-  const testCases: [string, string | undefined][] = [
-    ["@modelName arg1 arg2", "modelName"], // a leading @ yields the model name
-    [" @modelName arg1 arg2", undefined], // a leading space means it is not an @ command
-    ["plain text", undefined],
-  ];
-
-  for (const [args, expected] of testCases) {
-    const actual = extractAtModel(args);
-    assertEquals(actual.model, expected);
-  }
-});
diff --git a/lib/help.ts b/lib/help.ts
index 23ab665..b7fcaa5 100644
--- a/lib/help.ts
+++ b/lib/help.ts
@@ -1,10 +1,13 @@
 export const commandMessage = `
-    Help:
-    /?, /help Help for a command
-    /clear Clear session context
-    /bye Exit
-    @{ModelName} Change LLM model
-    ex) @gemini-1.5-pro any prompt...
+Ctrl+D to confirm input.
+
+Help:
+  /?, /help Help for a command
+  /clear Clear session context
+  /modelStack Show model history
+  /bye,/exit,/quit Exit
+  @ModelName Change LLM model
+  ex) @gemini-1.5-pro your question...
 `;

 export const helpMessage =
@@ -16,7 +19,7 @@ export const helpMessage =
   Options:
     -v, --version: boolean Show version
     -h, --help: boolean Show this message
-    -m, --model: string OpenAI, Anthropic, Google, Replicate, Ollama model (default gpt-4o-mini)
+    -m, --model: string LLM model (default gpt-4o-mini)
     -x, --max-tokens: number Number of AI answer tokens (default 1000)
     -t, --temperature: number Higher number means more creative answers, lower number means more exact answers (default 1.0)
     -u, --url: string URL and port number for ollama server
@@ -27,7 +30,10 @@ export const helpMessage =
   Models:
   - [OpenAI](https://platform.openai.com/docs/models)
     - gpt-4o-mini
-    - gpt-4o...
+    - gpt-4o
+    - o1
+    - o1-preview
+    - o1-mini...
   - [Anthropic](https://docs.anthropic.com/claude/docs/models-overview)
     - claude-3-5-sonnet-20241022
     - claude-3-5-sonnet-latest
@@ -36,18 +42,29 @@ export const helpMessage =
     - claude-instant-1.2...
   - [Gemini](https://ai.google.dev/gemini-api/docs/models/gemini)
     - gemini-1.5-pro-latest
-    - gemini-1.5-flash
-    - gemini-pro...
+    - gemini-2.0-flash-exp...
+  - [Groq](https://console.groq.com/docs/models)
+    - groq/llama3-groq-70b-8192-tool-use-preview
+    - groq/llama-3.3-70b-specdec
+    - groq/llama3.1-70b-specdec
+    - groq/llama-3.2-1b-preview
+    - groq/llama-3.2-3b-preview
+  - [TogetherAI](https://api.together.ai/models)
+    - togetherai/meta-llama/Llama-3.3-70B-Instruct-Turbo
+    - togetherai/Qwen/QwQ-32B-Preview
+    - togetherai/meta-llama/Llama-3.1-405B-Instruct-Turbo
+    - togetherai/google/gemma-2-27b-it
+    - togetherai/mistralai/Mistral-7B-Instruct-v0.3...
  - [Replicate](https://replicate.com/models)
-    - meta/meta-llama-3-70b-instruct
-    - meta/llama-2-7b-chat
-    - mistralai/mistral-7b-instruct-v0.2
-    - mistralai/mixtral-8x7b-instruct-v0.1
-    - snowflake/snowflake-arctic-instruct
-    - replicate/flan-t5-xl...
+    - replicate/meta/meta-llama-3-70b-instruct
+    - replicate/meta/llama-2-7b-chat
+    - replicate/mistralai/mistral-7b-instruct-v0.2
+    - replicate/mistralai/mixtral-8x7b-instruct-v0.1
+    - replicate/snowflake/snowflake-arctic-instruct
+    - replicate/replicate/flan-t5-xl...
   - [Ollama](https://ollama.com/library) ** Run "$ ollama serve" locally beforehand **
-    - phi3
-    - llama3:70b
-    - mixtral:8x7b-text-v0.1-q5_K_M...
+    - ollama/phi3
+    - ollama/llama3:70b
+    - ollama/mixtral:8x7b-text-v0.1-q5_K_M...

 ${commandMessage}
 `;
diff --git a/lib/input.ts b/lib/input.ts
index 12df9c0..27c3cdd 100644
--- a/lib/input.ts
+++ b/lib/input.ts
@@ -3,24 +3,44 @@ import { HumanMessage } from "npm:@langchain/core/messages";
 import { Message } from "./llm.ts";
 import { Command, newSlashCommand } from "./command.ts";

-/** Set the user input and the system prompt into messages */
+/** Return the user's input.
+ * Takes the last user input from the message array,
+ * or waits for new user input.
+ *
+ * - If the last message is not from the user: wait for new input
+ *   - a slash command: returns a Command object
+ *   - a normal message: returns a HumanMessage object
+ * - If the last message is from the user: return that HumanMessage
+ *
+ * @param {Message[]} messages - conversation history
+ * @returns {HumanMessage | Command} - the user input, or a slash Command
+ */
 export async function getUserInputInMessage(
   messages: Message[],
-): Promise<HumanMessage | undefined> {
+): Promise<HumanMessage | Command> {
   // If the last Message is not from the user,
   // wait for the user's question with endlessInput()
   const lastMessage: Message | undefined = messages.at(-1);
   // console.debug(lastMessage);
-  if (!(lastMessage instanceof HumanMessage)) {
-    const input = await endlessInput();
-    // Try to interpret input starting with / as a command
-    if (input.trim().startsWith("/")) {
-      const cmd = newSlashCommand(input);
-      if (cmd) return cmd;
-    }
+  if (lastMessage instanceof HumanMessage) {
+    return lastMessage;
+  }
+  // Wait until the user enters something
+  const input: string = await endlessInput();
+
+  // If it does not start with /, return it as the user's input
+  if (!input.trim().startsWith("/")) {
+    return new HumanMessage(input);
+  }
+
+  // Try to interpret input starting with / as a command
+  try {
+    const cmd = newSlashCommand(input);
+    return cmd;
+  } catch {
+    // On an invalid-command error,
+    // return it as a HumanMessage, slash included
     return new HumanMessage(input);
   }
-  return;
   // console.debug(messages);
 }
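With the widened return type, the caller dispatches on what comes back; a sketch that mirrors the index.ts loop (illustrative):

```typescript
import { handleSlashCommand, isSlashCommand } from "./lib/command.ts";
import { getUserInputInMessage } from "./lib/input.ts";
import { Message } from "./lib/llm.ts";

let messages: Message[] = [];
const input = await getUserInputInMessage(messages); // HumanMessage | Command
if (isSlashCommand(input)) {
  messages = handleSlashCommand(input, messages); // e.g. /clear rebuilds the array
} else {
  messages.push(input); // an ordinary HumanMessage joins the history
}
```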
diff --git a/lib/llm.ts b/lib/llm.ts
index e27ca52..0e188f2 100644
--- a/lib/llm.ts
+++ b/lib/llm.ts
@@ -2,6 +2,8 @@ import { ChatOpenAI } from "npm:@langchain/openai";
 import { ChatAnthropic } from "npm:@langchain/anthropic";
 import { ChatOllama } from "npm:@langchain/community/chat_models/ollama";
 import { ChatGoogleGenerativeAI } from "npm:@langchain/google-genai";
+import { ChatGroq } from "npm:@langchain/groq";
+import { ChatTogetherAI } from "npm:@langchain/community/chat_models/togetherai";
 import Replicate from "npm:replicate";
 import ServerSentEvent from "npm:replicate";
 import {
@@ -16,75 +18,29 @@ import { Params } from "./params.ts";

 /** AIMessage */
 export type Message = AIMessage | HumanMessage | SystemMessage | never; //{ role: Role; content: string };
-type Model = `${string}/${string}`;
+
+/** Models used via Replicate have the form
+ * owner/name or owner/name:version
+ */
+type ReplicateModel = `${string}/${string}`;
+
+/** Guarantee that the value is a ReplicateModel */
+const isReplicateModel = (value: unknown): value is ReplicateModel => {
+  return typeof value === "string" &&
+    value.includes("/") &&
+    value.split("/").length === 2;
+};
+
+type OpenModel = ChatGroq | ChatTogetherAI | ChatOllama | Replicate;
+type CloseModel = ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI;

 /** Create a Chat instance
  * @param: Params - LLM parameters and model
  */
 export class LLM {
-  public readonly transrator:
-    | ChatOpenAI
-    | ChatAnthropic
-    | ChatOllama
-    | ChatGoogleGenerativeAI
-    | Replicate
-    | undefined;
+  public readonly transrator?: OpenModel | CloseModel;

   constructor(private readonly params: Params) {
-    this.transrator = (() => {
-      const replicateModels = [
-        "llama",
-        "mistral",
-        "command-r",
-        "llava",
-        "mixtral",
-        "deepseek",
-        "phi",
-        "hermes",
-        "orca",
-        "falcon",
-        "dolphin",
-        "gemma",
-      ];
-      const replicateModelPatterns = replicateModels.map((m: string) =>
-        new RegExp(m)
-      );
-      if (params.url !== undefined) {
-        // true if params.model contains one of the ollama models
-        // ollamaModelPatterns.some((p: RegExp) => p.test(params.model))
-        return new ChatOllama({
-          baseUrl: params.url, // http://yourIP:11434
-          model: params.model, // "llama2:7b-chat", codellama:13b-fast-instruct, elyza:13b-fast-instruct ...
-          temperature: params.temperature,
-          // maxTokens: params.maxTokens, // Not implemented yet on Langchain
-        });
-      } else if (params.model.startsWith("gpt")) {
-        return new ChatOpenAI({
-          modelName: params.model,
-          temperature: params.temperature,
-          maxTokens: params.maxTokens,
-        });
-      } else if (params.model.startsWith("claude")) {
-        return new ChatAnthropic({
-          modelName: params.model,
-          temperature: params.temperature,
-          maxTokens: params.maxTokens,
-        });
-      } else if (params.model.startsWith("gemini")) {
-        return new ChatGoogleGenerativeAI({
-          model: params.model,
-          temperature: params.temperature,
-          maxOutputTokens: params.maxTokens,
-        });
-      } else if (
-        // true if params.model contains one of the replicate model names
-        replicateModelPatterns.some((p: RegExp) => p.test(params.model)) && // matches a replicate model pattern
-        (params.model as Model) === params.model // conforms to the Model type
-      ) {
-        return new Replicate();
-      } else {
-        throw new Error(`model not found "${params.model}"`);
-      }
-    })();
+    this.transrator = llmConstructor(params);
   }

   /** Ask the AI a one-off question, print the answer, and exit */
@@ -134,7 +90,7 @@ export class LLM {
     } else {
       const input = this.generateInput(messages);
       return (this.transrator as Replicate).stream(
-        this.params.model as Model,
+        this.params.model as ReplicateModel,
         { input },
       ) as AsyncGenerator;
     }
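The guard accepts exactly one slash, so the platform prefix has to be stripped before the check; its behavior in miniature (illustrative, duplicating the guard above):

```typescript
type ReplicateModel = `${string}/${string}`;

const isReplicateModel = (value: unknown): value is ReplicateModel =>
  typeof value === "string" && value.split("/").length === 2;

console.log(isReplicateModel("meta/meta-llama-3-70b-instruct")); // true (owner/name)
console.log(isReplicateModel("replicate/meta/meta-llama-3-70b-instruct")); // false
// parsePlatform (added below) strips the leading "replicate/" first.
```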
"You are helpful assistant."} }; // HumanMessageは[INST][/INST] で囲む // AIMessageは何もしない - const generatePrompt = (messages: (AIMessage | HumanMessage)[]): string => { + const surroundINST = (messages: (AIMessage | HumanMessage)[]): string => { return messages.map((message: AIMessage | HumanMessage, index: number) => { if (index === 0) { return `${message.content} [/INST]`; @@ -214,7 +170,7 @@ ${sys?.content ?? "You are helpful assistant."} }) .join("\n"); }; - const humanAIPrompt = generatePrompt(humanAIMessages); + const humanAIPrompt = surroundINST(humanAIMessages); return `[INST] ${systemPrompt}${humanAIPrompt}`; } @@ -231,3 +187,177 @@ async function* streamEncoder( yield str; } } + +type ModelMap = { [key: string]: (params: Params) => CloseModel }; + +// Platformオプション +// llamaモデルは共通のオープンモデルなので、 +// どこで実行するかをオプションで決める必要がある +export const platformList = [ + "ollama", + "groq", + "togetherai", + "replicate", +] as const; + +export type Platform = (typeof platformList)[number]; + +/** Platform型であることを保証する */ +export function isPlatform(value: unknown): value is Platform { + return typeof value === "string" && + platformList.includes(value as Platform); +} + +/** Platformごとに返すモデルのインスタンスを返す関数 */ +type PlatformMap = { [key in Platform]: (params: Params) => OpenModel }; + +/** LLM クラスのtransratorプロパティをparamsから判定し、 + * LLM インスタンスを生成して返す。 + * @param{Params} params - command line arguments parsed by parseArgs() + * @return : LLM model + * @throws{Error} model not found "${params.model}" + */ +function llmConstructor(params: Params): OpenModel | CloseModel { + const modelMap: ModelMap = { + "^gpt": createOpenAIInstance, + "^o[0-9]": createOpenAIOModelInstance, + "^claude": createAnthropicInstance, + "^gemini": createGoogleGenerativeAIInstance, + } as const; + + const platformMap: PlatformMap = { + "groq": createGroqInstance, + "togetherai": createTogetherAIInstance, + "ollama": createOllamaInstance, + "replicate": createReplicateInstance, + } as const; + + // Closed modelがインスタンス化できるか + // 正規表現でマッチング + const createInstance = Object.keys(modelMap).find((regex) => + new RegExp(regex).test(params.model) + ); + + // Closed modelが見つかればそれをインスタンス化して返す + if (createInstance !== undefined) { + return modelMap[createInstance](params); + } + + // Closed modelでマッチするモデルが見つからなかった場合、 + // Open model がインスタンス化できるか。 + // + // llamaなどのオープンモデルはモデル名ではなく、 + // platform名で判定する + + // platformが特定できないときは空文字が返る + const { platform, model } = parsePlatform(params.model); + // platformがオプションに指定されていなければエラー + if (!isPlatform(platform)) { + throw new Error( + `unknown platform "${platform}", choose from ${platformList.join(", ")}`, + ); + } + + // platformMap からオプションに指定したものがなければエラー + const createInstanceFromPlatform = platformMap[platform]; + if (createInstanceFromPlatform === undefined) { + throw new Error(`unknown model ${model}`); + } + + return createInstanceFromPlatform(params); +} + +const createOpenAIInstance = (params: Params): ChatOpenAI => { + return new ChatOpenAI({ + modelName: params.model, + temperature: params.temperature, + maxTokens: params.maxTokens, + }); +}; + +const createOpenAIOModelInstance = (params: Params): ChatOpenAI => { + return new ChatOpenAI({ + modelName: params.model, + temperature: params.temperature, + // max_completion_tokens: params.maxTokens, + }); +}; + +const createAnthropicInstance = (params: Params): ChatAnthropic => { + return new ChatAnthropic({ + modelName: params.model, + temperature: params.temperature, + maxTokens: params.maxTokens, + }); +}; + +const createGoogleGenerativeAIInstance = ( 
@@ -231,3 +187,177 @@ async function* streamEncoder(
     yield str;
   }
 }
+
+type ModelMap = { [key: string]: (params: Params) => CloseModel };
+
+// Platform option:
+// llama models are open models shared across providers,
+// so an option has to decide where to run them
+export const platformList = [
+  "ollama",
+  "groq",
+  "togetherai",
+  "replicate",
+] as const;
+
+export type Platform = (typeof platformList)[number];
+
+/** Guarantee that the value is a Platform */
+export function isPlatform(value: unknown): value is Platform {
+  return typeof value === "string" &&
+    platformList.includes(value as Platform);
+}
+
+/** Functions returning the model instance for each Platform */
+type PlatformMap = { [key in Platform]: (params: Params) => OpenModel };
+
+/** Determine the LLM class's transrator property from params
+ * and return the created instance.
+ * @param {Params} params - command line arguments parsed by parseArgs()
+ * @return : LLM model
+ * @throws {Error} model not found "${params.model}"
+ */
+function llmConstructor(params: Params): OpenModel | CloseModel {
+  const modelMap: ModelMap = {
+    "^gpt": createOpenAIInstance,
+    "^o[0-9]": createOpenAIOModelInstance,
+    "^claude": createAnthropicInstance,
+    "^gemini": createGoogleGenerativeAIInstance,
+  } as const;
+
+  const platformMap: PlatformMap = {
+    "groq": createGroqInstance,
+    "togetherai": createTogetherAIInstance,
+    "ollama": createOllamaInstance,
+    "replicate": createReplicateInstance,
+  } as const;
+
+  // Check whether a closed model can be instantiated,
+  // matching by regular expression
+  const createInstance = Object.keys(modelMap).find((regex) =>
+    new RegExp(regex).test(params.model)
+  );
+
+  // If a closed model matched, instantiate and return it
+  if (createInstance !== undefined) {
+    return modelMap[createInstance](params);
+  }
+
+  // If no closed model matched,
+  // check whether an open model can be instantiated.
+  //
+  // Open models such as llama are identified by platform name,
+  // not by model name
+
+  // parsePlatform returns an empty string when the platform cannot be determined
+  const { platform, model } = parsePlatform(params.model);
+  // Error if the platform is not one of the supported options
+  if (!isPlatform(platform)) {
+    throw new Error(
+      `unknown platform "${platform}", choose from ${platformList.join(", ")}`,
+    );
+  }
+
+  // Error if the platform has no entry in platformMap
+  const createInstanceFromPlatform = platformMap[platform];
+  if (createInstanceFromPlatform === undefined) {
+    throw new Error(`unknown model ${model}`);
+  }
+
+  return createInstanceFromPlatform(params);
+}
+
+const createOpenAIInstance = (params: Params): ChatOpenAI => {
+  return new ChatOpenAI({
+    modelName: params.model,
+    temperature: params.temperature,
+    maxTokens: params.maxTokens,
+  });
+};
+
+const createOpenAIOModelInstance = (params: Params): ChatOpenAI => {
+  return new ChatOpenAI({
+    modelName: params.model,
+    temperature: params.temperature,
+    // max_completion_tokens: params.maxTokens,
+  });
+};
+
+const createAnthropicInstance = (params: Params): ChatAnthropic => {
+  return new ChatAnthropic({
+    modelName: params.model,
+    temperature: params.temperature,
+    maxTokens: params.maxTokens,
+  });
+};
+
+const createGoogleGenerativeAIInstance = (
+  params: Params,
+): ChatGoogleGenerativeAI => {
+  return new ChatGoogleGenerativeAI({
+    model: params.model,
+    temperature: params.temperature,
+    maxOutputTokens: params.maxTokens,
+  });
+};
+
+const createGroqInstance = (params: Params): ChatGroq => {
+  const { platform: _, model } = parsePlatform(params.model);
+  return new ChatGroq({
+    model: model,
+    temperature: params.temperature,
+    maxTokens: params.maxTokens,
+  });
+};
+
+const createTogetherAIInstance = (params: Params): ChatTogetherAI => {
+  const { platform: _, model } = parsePlatform(params.model);
+  return new ChatTogetherAI({
+    model: model,
+    temperature: params.temperature,
+    maxTokens: params.maxTokens,
+  });
+};
+
+const createOllamaInstance = (params: Params): ChatOllama => {
+  // Ollama requires the baseUrl of the server where ollama is running
+  if (params.url === undefined) {
+    throw new Error(
+      "ollama needs URL parameter with `--url http://your.host:11434`",
+    );
+  }
+  const { platform: _, model } = parsePlatform(params.model);
+  return new ChatOllama({
+    baseUrl: params.url, // http://yourIP:11434
+    model: model, // "llama2:7b-chat", codellama:13b-fast-instruct, elyza:13b-fast-instruct ...
+    temperature: params.temperature,
+    // maxTokens: params.maxTokens, // Not implemented yet on Langchain
+  });
+};
+
+const createReplicateInstance = (params: Params): Replicate => {
+  const { platform: _, model } = parsePlatform(params.model);
+  if (isReplicateModel(model)) {
+    return new Replicate();
+  } else {
+    throw new Error(
+      `Invalid reference to model version: "${model}". Expected format: owner/name or owner/name:version`,
+    );
+  }
+};
+
+/** Split the argument at the first "/".
+ * Return the first part as platform and
+ * the rest as model.
+ */
+export function parsePlatform(
+  model: string,
+): { platform: string; model: string } {
+  const parts = model.split("/");
+  if (parts.length < 2) {
+    return { platform: "", model: model };
+  }
+  const platform = parts[0];
+  const modelName = parts.slice(1).join("/");
+  return { platform, model: modelName };
+}
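And parsePlatform's split-on-first-slash behavior (illustrative; the tests below cover the edge cases):

```typescript
import { parsePlatform } from "./lib/llm.ts";

console.log(parsePlatform("togetherai/google/gemma-2-27b-it"));
// { platform: "togetherai", model: "google/gemma-2-27b-it" }
console.log(parsePlatform("gpt-4o-mini"));
// { platform: "", model: "gpt-4o-mini" } (no "/" means no platform prefix)
```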
platform: "ollama", temperature: 0.7, maxTokens: 2048, }; @@ -92,7 +100,7 @@ Deno.test("Should create a Replicate instance for an Replicate model", () => { help: false, noChat: false, debug: false, - model: "meta/llama2:7b-chat", + model: "replicate/meta/llama2:7b-chat", url: undefined, temperature: 0.7, maxTokens: 2048, @@ -104,6 +112,47 @@ Deno.test("Should create a Replicate instance for an Replicate model", () => { ); }); +Deno.test("Should create a Groq instance for an Groq model", () => { + Deno.env.set("GROQ_API_KEY", "sk-11111"); + const params = { + version: false, + help: false, + noChat: false, + debug: false, + model: "groq/llama-3.3-70b-versatile", + url: undefined, + temperature: 0.7, + maxTokens: 2048, + }; + const llm = new LLM(params); + assert( + llm.transrator instanceof ChatGroq, + `Expected LLM instance to be ChatGroq, but got ${llm.constructor.name}`, + ); + assertEquals(llm.transrator.model, "llama-3.3-70b-versatile"); + assertEquals(llm.transrator.temperature, 0.7); +}); + +Deno.test("Should create a TogetherAI instance for an TogetherAI model", () => { + Deno.env.set("TOGETHER_AI_API_KEY", "sk-11111"); + const params = { + version: false, + help: false, + noChat: false, + debug: false, + model: "togetherai/google/gemma-2-27b-it", + url: undefined, + temperature: 0.7, + maxTokens: 2048, + }; + const llm = new LLM(params); + assert( + llm.transrator instanceof ChatTogetherAI, + `Expected LLM instance to be ChatTogetherAI, but got ${llm.constructor.name}`, + ); + assertEquals(llm.transrator.model, "google/gemma-2-27b-it"); +}); + Deno.test("Should throw an error for an unknown model", () => { const params = { version: false, @@ -114,7 +163,11 @@ Deno.test("Should throw an error for an unknown model", () => { temperature: 0.7, maxTokens: 2048, }; - assertThrows(() => new LLM(params), Error, 'model not found "unknown-model"'); + assertThrows( + () => new LLM(params), + Error, + `unknown platform "", choose from ${platformList.join(", ")}`, + ); }); Deno.test("Replicate prompt generator", () => { @@ -155,3 +208,27 @@ hello, how can I help you? I have no name, just an AI`, ); }); + +Deno.test("parsePlatform - valid model string", () => { + const { platform, model } = parsePlatform("replicate/meta/llama3.3-70b"); + assertEquals(platform, "replicate"); + assertEquals(model, "meta/llama3.3-70b"); +}); + +Deno.test("parsePlatform - model string with only one part", () => { + const { platform, model } = parsePlatform("modelonly"); + assertEquals(platform, ""); + assertEquals(model, "modelonly"); +}); + +Deno.test("parsePlatform - model string with multiple slashes", () => { + const { platform, model } = parsePlatform("a/b/c/d"); + assertEquals(platform, "a"); + assertEquals(model, "b/c/d"); +}); + +Deno.test("parsePlatform - model string starts with slash", () => { + const { platform, model } = parsePlatform("/a/b/c"); + assertEquals(platform, ""); + assertEquals(model, "a/b/c"); +}); diff --git a/lib/params_test.ts b/test/params_test.ts similarity index 100% rename from lib/params_test.ts rename to test/params_test.ts