diff --git a/autogpt/.env.template b/autogpt/.env.template index 9d458b90a55d..8d4894988c53 100644 --- a/autogpt/.env.template +++ b/autogpt/.env.template @@ -11,6 +11,9 @@ ## GROQ_API_KEY - Groq API Key (Example: gsk_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx) # GROQ_API_KEY= +## LLAMAFILE_API_BASE - Llamafile API base URL +# LLAMAFILE_API_BASE=http://localhost:8080/v1 + ## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry. ## This helps us to spot and solve problems earlier & faster. (Default: DISABLED) # TELEMETRY_OPT_IN=true diff --git a/autogpt/scripts/llamafile/.gitignore b/autogpt/scripts/llamafile/.gitignore new file mode 100644 index 000000000000..3aa496945a5e --- /dev/null +++ b/autogpt/scripts/llamafile/.gitignore @@ -0,0 +1,3 @@ +*.llamafile +*.llamafile.exe +llamafile.exe diff --git a/autogpt/scripts/llamafile/serve.py b/autogpt/scripts/llamafile/serve.py new file mode 100755 index 000000000000..431fa253f19f --- /dev/null +++ b/autogpt/scripts/llamafile/serve.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Use llamafile to serve a (quantized) mistral-7b-instruct-v0.2 model + +Usage: + cd /autogpt + ./scripts/llamafile/serve.py +""" + +import os +import platform +import subprocess +from pathlib import Path +from typing import Optional + +import click + +LLAMAFILE = Path("mistral-7b-instruct-v0.2.Q5_K_M.llamafile") +LLAMAFILE_URL = f"https://huggingface.co/jartine/Mistral-7B-Instruct-v0.2-llamafile/resolve/main/{LLAMAFILE.name}" # noqa +LLAMAFILE_EXE = Path("llamafile.exe") +LLAMAFILE_EXE_URL = "https://github.com/Mozilla-Ocho/llamafile/releases/download/0.8.6/llamafile-0.8.6" # noqa + + +@click.command() +@click.option( + "--llamafile", + type=click.Path(dir_okay=False, path_type=Path), + help=f"Name of the llamafile to serve. Default: {LLAMAFILE.name}", +) +@click.option("--llamafile_url", help="Download URL for the llamafile you want to use") +@click.option( + "--host", help="Specify the address for the llamafile server to listen on" +) +@click.option( + "--port", type=int, help="Specify the port for the llamafile server to listen on" +) +@click.option( + "--force-gpu", + is_flag=True, + hidden=platform.system() != "Darwin", + help="Run the model using only the GPU (AMD or Nvidia). " + "Otherwise, both CPU and GPU may be (partially) used.", +) +def main( + llamafile: Optional[Path] = None, + llamafile_url: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + force_gpu: bool = False, +): + print(f"type(llamafile) = {type(llamafile)}") + if not llamafile: + if not llamafile_url: + llamafile = LLAMAFILE + else: + llamafile = Path(llamafile_url.rsplit("/", 1)[1]) + if llamafile.suffix != ".llamafile": + click.echo( + click.style( + "The given URL does not end with '.llamafile' -> " + "can't get filename from URL. " + "Specify the filename using --llamafile.", + fg="red", + ), + err=True, + ) + return + + if llamafile == LLAMAFILE and not llamafile_url: + llamafile_url = LLAMAFILE_URL + elif llamafile_url != LLAMAFILE_URL: + if not click.prompt( + click.style( + "You seem to have specified a different URL for the default model " + f"({llamafile.name}). Are you sure this is correct? 
" + "If you want to use a different model, also specify --llamafile.", + fg="yellow", + ), + type=bool, + ): + return + + # Go to autogpt/scripts/llamafile/ + os.chdir(Path(__file__).resolve().parent) + + on_windows = platform.system() == "Windows" + + if not llamafile.is_file(): + if not llamafile_url: + click.echo( + click.style( + "Please use --lamafile_url to specify a download URL for " + f"'{llamafile.name}'. " + "This will only be necessary once, so we can download the model.", + fg="red", + ), + err=True, + ) + return + + download_file(llamafile_url, llamafile) + + if not on_windows: + llamafile.chmod(0o755) + subprocess.run([llamafile, "--version"], check=True) + + if not on_windows: + base_command = [f"./{llamafile}"] + else: + # Windows does not allow executables over 4GB, so we have to download a + # model-less llamafile.exe and run that instead. + if not LLAMAFILE_EXE.is_file(): + download_file(LLAMAFILE_EXE_URL, LLAMAFILE_EXE) + LLAMAFILE_EXE.chmod(0o755) + subprocess.run([f".\\{LLAMAFILE_EXE}", "--version"], check=True) + + base_command = [f".\\{LLAMAFILE_EXE}", "-m", llamafile] + + if host: + base_command.extend(["--host", host]) + if port: + base_command.extend(["--port", str(port)]) + if force_gpu: + base_command.extend(["-ngl", "9999"]) + + subprocess.run( + [ + *base_command, + "--server", + "--nobrowser", + "--ctx-size", + "0", + "--n-predict", + "1024", + ], + check=True, + ) + + # note: --ctx-size 0 means the prompt context size will be set directly from the + # underlying model configuration. This may cause slow response times or consume + # a lot of memory. + + +def download_file(url: str, to_file: Path) -> None: + print(f"Downloading {to_file.name}...") + import urllib.request + + urllib.request.urlretrieve(url, to_file, reporthook=report_download_progress) + print() + + +def report_download_progress(chunk_number: int, chunk_size: int, total_size: int): + if total_size != -1: + downloaded_size = chunk_number * chunk_size + percent = min(1, downloaded_size / total_size) + bar = "#" * int(40 * percent) + print( + f"\rDownloading: [{bar:<40}] {percent:.0%}" + f" - {downloaded_size/1e6:.1f}/{total_size/1e6:.1f} MB", + end="", + ) + + +if __name__ == "__main__": + main() diff --git a/docs/content/AutoGPT/configuration/options.md b/docs/content/AutoGPT/configuration/options.md index 49e111330f00..ed0f70885d7d 100644 --- a/docs/content/AutoGPT/configuration/options.md +++ b/docs/content/AutoGPT/configuration/options.md @@ -22,6 +22,7 @@ You can set configuration variables via the `.env` file. If you don't have a `.e - `GROQ_API_KEY`: Set this if you want to use Groq models with AutoGPT - `HUGGINGFACE_API_TOKEN`: HuggingFace API, to be used for both image generation and audio to text. Optional. - `HUGGINGFACE_IMAGE_MODEL`: HuggingFace model to use for image generation. Default: CompVis/stable-diffusion-v1-4 +- `LLAMAFILE_API_BASE`: Llamafile API base URL. Default: `http://localhost:8080/v1` - `OPENAI_API_KEY`: Set this if you want to use OpenAI models; [OpenAI API Key](https://platform.openai.com/account/api-keys). - `OPENAI_ORGANIZATION`: Organization ID in OpenAI. Optional. - `PLAIN_OUTPUT`: Plain output, which disables the spinner. Default: False diff --git a/docs/content/AutoGPT/setup/index.md b/docs/content/AutoGPT/setup/index.md index 4fcdbf1cac44..78b9632ac19f 100644 --- a/docs/content/AutoGPT/setup/index.md +++ b/docs/content/AutoGPT/setup/index.md @@ -198,3 +198,66 @@ If you don't know which to choose, you can safely go with OpenAI*. 
[groq/api-keys]: https://console.groq.com/keys
[groq/models]: https://console.groq.com/docs/models
+
+
+### Llamafile
+
+With llamafile you can run models locally, which means you don't need to set up
+billing and your data never leaves your machine.
+
+For more information and in-depth documentation, check out the [llamafile documentation].
+
+!!! warning
+    At the moment, llamafile only serves one model at a time. This means you cannot
+    set `SMART_LLM` and `FAST_LLM` to two different llamafile models.
+
+!!! warning
+    Due to the issues linked below, llamafiles don't work on WSL. To use a llamafile
+    with AutoGPT in WSL, you will have to run the llamafile in Windows (outside WSL).
+    Instructions
+
+    1. Get the `llamafile/serve.py` script in one of two ways:
+        1. Clone the AutoGPT repo somewhere in your Windows environment,
+           with the script located at `autogpt/scripts/llamafile/serve.py`
+        2. Download just the [serve.py] script somewhere in your Windows environment
+    2. Make sure you have `click` installed: `pip install click`
+    3. Run `ip route | grep default | awk '{print $3}'` *inside WSL* to get the address
+       of the WSL host machine
+    4. Run `python3 serve.py --host {WSL_HOST_ADDR}`, where `{WSL_HOST_ADDR}`
+       is the address you found in step 3.
+       If port 8080 is taken, also specify a different port using `--port {PORT}`.
+    5. In WSL, set `LLAMAFILE_API_BASE=http://{WSL_HOST_ADDR}:8080/v1` in your `.env`.
+    6. Follow the rest of the regular instructions below.
+
+    [serve.py]: https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt/scripts/llamafile/serve.py
+
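+    To check that the server you started in Windows is reachable from inside WSL,
+    you can run a quick Python probe. This is a minimal sketch; it assumes `requests`
+    is installed, the server listens on port 8080, and the host address below is
+    replaced with the one you found in step 3:
+
+    ```python
+    import requests
+
+    WSL_HOST_ADDR = "172.29.80.1"  # hypothetical example address; use your own
+    api_base = f"http://{WSL_HOST_ADDR}:8080/v1"
+
+    # llamafile exposes an OpenAI-compatible API, so listing /v1/models should
+    # succeed and show the model the server is currently serving.
+    response = requests.get(f"{api_base}/models", timeout=5)
+    response.raise_for_status()
+    print(response.json())
+    ```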
+
+    * [Mozilla-Ocho/llamafile#356](https://github.com/Mozilla-Ocho/llamafile/issues/356)
+    * [Mozilla-Ocho/llamafile#100](https://github.com/Mozilla-Ocho/llamafile/issues/100)
+
+!!! note
+    These instructions will download and use `mistral-7b-instruct-v0.2.Q5_K_M.llamafile`.
+    `mistral-7b-instruct-v0.2` is currently the only tested and supported model.
+    If you want to try other models, you'll have to add them to `LlamafileModelName` in
+    [`llamafile.py`][forge/llamafile.py].
+    For optimal results, you may also have to add some logic to adapt the message format,
+    like `LlamafileProvider._adapt_chat_messages_for_mistral_instruct(..)` does.
+
+1. Run the llamafile serve script:
+   ```shell
+   python3 ./scripts/llamafile/serve.py
+   ```
+   The first time this is run, it will download a file containing both the model
+   and the runtime, which may take a while and use a few gigabytes of disk space.
+
+   To force GPU acceleration, add `--force-gpu` to the command.
+
+2. In `.env`, set `SMART_LLM`, `FAST_LLM`, or both to `mistral-7b-instruct-v0.2`
+
+3. If the server is running at a different address than `http://localhost:8080/v1`,
+   set `LLAMAFILE_API_BASE` in `.env` to the right base URL
+
+[llamafile documentation]: https://github.com/Mozilla-Ocho/llamafile#readme
+[forge/llamafile.py]: https://github.com/Significant-Gravitas/AutoGPT/blob/master/forge/forge/llm/providers/llamafile/llamafile.py
diff --git a/forge/forge/llm/providers/llamafile/README.md b/forge/forge/llm/providers/llamafile/README.md
new file mode 100644
index 000000000000..e4276ec1a941
--- /dev/null
+++ b/forge/forge/llm/providers/llamafile/README.md
@@ -0,0 +1,36 @@
+# Llamafile Integration Notes
+
+Tested with:
+* Python 3.11
+* Apple M2 Pro (32 GB), macOS 14.2.1
+* quantized mistral-7b-instruct-v0.2
+
+## Setup
+
+Download a `mistral-7b-instruct-v0.2` llamafile:
+```shell
+wget -nc https://huggingface.co/jartine/Mistral-7B-Instruct-v0.2-llamafile/resolve/main/mistral-7b-instruct-v0.2.Q5_K_M.llamafile
+chmod +x mistral-7b-instruct-v0.2.Q5_K_M.llamafile
+./mistral-7b-instruct-v0.2.Q5_K_M.llamafile --version
+```
+
+Run the llamafile server:
+```shell
+LLAMAFILE="./mistral-7b-instruct-v0.2.Q5_K_M.llamafile"
+
+"${LLAMAFILE}" \
+--server \
+--nobrowser \
+--ctx-size 0 \
+--n-predict 1024
+
+# note: ctx-size=0 means the prompt context size will be set directly from the
+# underlying model configuration. This may cause slow response times or consume
+# a lot of memory.
+```
+
+## TODOs
+
+* `SMART_LLM`/`FAST_LLM` configuration: Currently, the llamafile server only serves one model at a time. However, there's no reason you can't start multiple llamafile servers on different ports. To support using different models for `smart_llm` and `fast_llm`, you could implement config vars like `LLAMAFILE_SMART_LLM_URL` and `LLAMAFILE_FAST_LLM_URL` that point to different llamafile servers (one serving a 'big model' and one serving a 'fast model').
+* Authorization: the `serve.py` script does not set up any authorization for the llamafile server; this can be turned on by adding the `--api-key <api-key>` argument to the server startup command. However, I haven't attempted to test whether the integration with AutoGPT works when this feature is turned on.
+* Test with other models diff --git a/forge/forge/llm/providers/llamafile/__init__.py b/forge/forge/llm/providers/llamafile/__init__.py new file mode 100644 index 000000000000..23706b102be6 --- /dev/null +++ b/forge/forge/llm/providers/llamafile/__init__.py @@ -0,0 +1,17 @@ +from .llamafile import ( + LLAMAFILE_CHAT_MODELS, + LLAMAFILE_EMBEDDING_MODELS, + LlamafileCredentials, + LlamafileModelName, + LlamafileProvider, + LlamafileSettings, +) + +__all__ = [ + "LLAMAFILE_CHAT_MODELS", + "LLAMAFILE_EMBEDDING_MODELS", + "LlamafileCredentials", + "LlamafileModelName", + "LlamafileProvider", + "LlamafileSettings", +] diff --git a/forge/forge/llm/providers/llamafile/llamafile.py b/forge/forge/llm/providers/llamafile/llamafile.py new file mode 100644 index 000000000000..7eb7afaafb10 --- /dev/null +++ b/forge/forge/llm/providers/llamafile/llamafile.py @@ -0,0 +1,351 @@ +import enum +import logging +import re +from pathlib import Path +from typing import Any, Iterator, Optional, Sequence + +import requests +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletionMessageParam, + CompletionCreateParams, +) +from pydantic import SecretStr + +from forge.json.parsing import json_loads +from forge.models.config import UserConfigurable + +from .._openai_base import BaseOpenAIChatProvider +from ..schema import ( + AssistantToolCall, + AssistantToolCallDict, + ChatMessage, + ChatModelInfo, + CompletionModelFunction, + ModelProviderConfiguration, + ModelProviderCredentials, + ModelProviderName, + ModelProviderSettings, + ModelTokenizer, +) + + +class LlamafileModelName(str, enum.Enum): + MISTRAL_7B_INSTRUCT = "mistral-7b-instruct-v0.2" + + +LLAMAFILE_CHAT_MODELS = { + info.name: info + for info in [ + ChatModelInfo( + name=LlamafileModelName.MISTRAL_7B_INSTRUCT, + provider_name=ModelProviderName.LLAMAFILE, + prompt_token_cost=0.0, + completion_token_cost=0.0, + max_tokens=32768, + has_function_call_api=False, + ), + ] +} + +LLAMAFILE_EMBEDDING_MODELS = {} + + +class LlamafileConfiguration(ModelProviderConfiguration): + # TODO: implement 'seed' across forge.llm.providers + seed: Optional[int] = None + + +class LlamafileCredentials(ModelProviderCredentials): + api_key: Optional[SecretStr] = SecretStr("sk-no-key-required") + api_base: SecretStr = UserConfigurable( # type: ignore + default=SecretStr("http://localhost:8080/v1"), from_env="LLAMAFILE_API_BASE" + ) + + def get_api_access_kwargs(self) -> dict[str, str]: + return { + k: v.get_secret_value() + for k, v in { + "api_key": self.api_key, + "base_url": self.api_base, + }.items() + if v is not None + } + + +class LlamafileSettings(ModelProviderSettings): + configuration: LlamafileConfiguration # type: ignore + credentials: Optional[LlamafileCredentials] = None # type: ignore + + +class LlamafileTokenizer(ModelTokenizer[int]): + def __init__(self, credentials: LlamafileCredentials): + self._credentials = credentials + + @property + def _tokenizer_base_url(self): + # The OpenAI-chat-compatible base url should look something like + # 'http://localhost:8080/v1' but the tokenizer endpoint is + # 'http://localhost:8080/tokenize'. So here we just strip off the '/v1'. 
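+        # e.g. 'http://localhost:8080/v1' -> 'http://localhost:8080'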
+ api_base = self._credentials.api_base.get_secret_value() + return api_base.strip("/v1") + + def encode(self, text: str) -> list[int]: + response = requests.post( + url=f"{self._tokenizer_base_url}/tokenize", json={"content": text} + ) + response.raise_for_status() + return response.json()["tokens"] + + def decode(self, tokens: list[int]) -> str: + response = requests.post( + url=f"{self._tokenizer_base_url}/detokenize", json={"tokens": tokens} + ) + response.raise_for_status() + return response.json()["content"] + + +class LlamafileProvider( + BaseOpenAIChatProvider[LlamafileModelName, LlamafileSettings], + # TODO: add and test support for embedding models + # BaseOpenAIEmbeddingProvider[LlamafileModelName, LlamafileSettings], +): + EMBEDDING_MODELS = LLAMAFILE_EMBEDDING_MODELS + CHAT_MODELS = LLAMAFILE_CHAT_MODELS + MODELS = {**CHAT_MODELS, **EMBEDDING_MODELS} + + default_settings = LlamafileSettings( + name="llamafile_provider", + description=( + "Provides chat completion and embedding services " + "through a llamafile instance" + ), + configuration=LlamafileConfiguration(), + ) + + _settings: LlamafileSettings + _credentials: LlamafileCredentials + _configuration: LlamafileConfiguration + + async def get_available_models(self) -> Sequence[ChatModelInfo[LlamafileModelName]]: + _models = (await self._client.models.list()).data + # note: at the moment, llamafile only serves one model at a time (so this + # list will only ever have one value). however, in the future, llamafile + # may support multiple models, so leaving this method as-is for now. + self._logger.debug(f"Retrieved llamafile models: {_models}") + + clean_model_ids = [clean_model_name(m.id) for m in _models] + self._logger.debug(f"Cleaned llamafile model IDs: {clean_model_ids}") + + return [ + LLAMAFILE_CHAT_MODELS[id] + for id in clean_model_ids + if id in LLAMAFILE_CHAT_MODELS + ] + + def get_tokenizer(self, model_name: LlamafileModelName) -> LlamafileTokenizer: + return LlamafileTokenizer(self._credentials) + + def count_message_tokens( + self, + messages: ChatMessage | list[ChatMessage], + model_name: LlamafileModelName, + ) -> int: + if isinstance(messages, ChatMessage): + messages = [messages] + + if model_name == LlamafileModelName.MISTRAL_7B_INSTRUCT: + # For mistral-instruct, num added tokens depends on if the message + # is a prompt/instruction or an assistant-generated message. + # - prompt gets [INST], [/INST] added and the first instruction + # begins with '' ('beginning-of-sentence' token). 
+ # - assistant-generated messages get '' added + # see: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2 + # + prompt_added = 1 # one for '' token + assistant_num_added = 0 + ntokens = 0 + for message in messages: + if ( + message.role == ChatMessage.Role.USER + # note that 'system' messages will get converted + # to 'user' messages before being sent to the model + or message.role == ChatMessage.Role.SYSTEM + ): + # 5 tokens for [INST], [/INST], which actually get + # tokenized into "[, INST, ]" and "[, /, INST, ]" + # by the mistral tokenizer + prompt_added += 5 + elif message.role == ChatMessage.Role.ASSISTANT: + assistant_num_added += 1 # for + else: + raise ValueError( + f"{model_name} does not support role: {message.role}" + ) + + ntokens += self.count_tokens(message.content, model_name) + + total_token_count = prompt_added + assistant_num_added + ntokens + return total_token_count + + else: + raise NotImplementedError( + f"count_message_tokens not implemented for model {model_name}" + ) + + def _get_chat_completion_args( + self, + prompt_messages: list[ChatMessage], + model: LlamafileModelName, + functions: list[CompletionModelFunction] | None = None, + max_output_tokens: int | None = None, + **kwargs, + ) -> tuple[ + list[ChatCompletionMessageParam], CompletionCreateParams, dict[str, Any] + ]: + messages, completion_kwargs, parse_kwargs = super()._get_chat_completion_args( + prompt_messages, model, functions, max_output_tokens, **kwargs + ) + + if model == LlamafileModelName.MISTRAL_7B_INSTRUCT: + messages = self._adapt_chat_messages_for_mistral_instruct(messages) + + if "seed" not in kwargs and self._configuration.seed is not None: + completion_kwargs["seed"] = self._configuration.seed + + # Convert all messages with content blocks to simple text messages + for message in messages: + if isinstance(content := message.get("content"), list): + message["content"] = "\n\n".join( + b["text"] + for b in content + if b["type"] == "text" + # FIXME: add support for images through image_data completion kwarg + ) + + return messages, completion_kwargs, parse_kwargs + + def _adapt_chat_messages_for_mistral_instruct( + self, messages: list[ChatCompletionMessageParam] + ) -> list[ChatCompletionMessageParam]: + """ + Munge the messages to be compatible with the mistral-7b-instruct chat + template, which: + - only supports 'user' and 'assistant' roles. + - expects messages to alternate between user/assistant roles. 
+ + See details here: + https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format + """ + adapted_messages: list[ChatCompletionMessageParam] = [] + for message in messages: + # convert 'system' role to 'user' role as mistral-7b-instruct does + # not support 'system' + if message["role"] == ChatMessage.Role.SYSTEM: + message["role"] = ChatMessage.Role.USER + + if ( + len(adapted_messages) == 0 + or message["role"] != (last_message := adapted_messages[-1])["role"] + ): + adapted_messages.append(message) + else: + if not message.get("content"): + continue + + # if the curr message has the same role as the previous one, + # concat the current message content to the prev message + if message["role"] == "user" and last_message["role"] == "user": + # user messages can contain other types of content blocks + if not isinstance(last_message["content"], list): + last_message["content"] = [ + {"type": "text", "text": last_message["content"]} + ] + + last_message["content"].extend( + message["content"] + if isinstance(message["content"], list) + else [{"type": "text", "text": message["content"]}] + ) + elif message["role"] != "user" and last_message["role"] != "user": + last_message["content"] = ( + (last_message.get("content") or "") + + "\n\n" + + (message.get("content") or "") + ).strip() + + return adapted_messages + + def _parse_assistant_tool_calls( + self, + assistant_message: ChatCompletionMessage, + compat_mode: bool = False, + **kwargs, + ): + tool_calls: list[AssistantToolCall] = [] + parse_errors: list[Exception] = [] + + if compat_mode and assistant_message.content: + try: + tool_calls = list( + _tool_calls_compat_extract_calls(assistant_message.content) + ) + except Exception as e: + parse_errors.append(e) + + return tool_calls, parse_errors + + +def clean_model_name(model_file: str) -> str: + """ + Clean up model names: + 1. Remove file extension + 2. 
Remove quantization info + + Examples: + ``` + raw: 'mistral-7b-instruct-v0.2.Q5_K_M.gguf' + clean: 'mistral-7b-instruct-v0.2' + + raw: '/Users/kate/models/mistral-7b-instruct-v0.2.Q5_K_M.gguf' + clean: 'mistral-7b-instruct-v0.2' + + raw: 'llava-v1.5-7b-q4.gguf' + clean: 'llava-v1.5-7b' + ``` + """ + name_without_ext = Path(model_file).name.rsplit(".", 1)[0] + name_without_Q = re.match( + r"^[a-zA-Z0-9]+([.\-](?!([qQ]|B?F)\d{1,2})[a-zA-Z0-9]+)*", + name_without_ext, + ) + return name_without_Q.group() if name_without_Q else name_without_ext + + +def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]: + import re + import uuid + + logging.debug(f"Trying to extract tool calls from response:\n{response}") + + response = response.strip() # strip off any leading/trailing whitespace + if response.startswith("```"): + # attempt to remove any extraneous markdown artifacts like "```json" + response = response.strip("```") + if response.startswith("json"): + response = response.strip("json") + response = response.strip() # any remaining whitespace + + if response[0] == "[": + tool_calls: list[AssistantToolCallDict] = json_loads(response) + else: + block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL) + if not block: + raise ValueError("Could not find tool_calls block in response") + tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1)) + + for t in tool_calls: + t["id"] = str(uuid.uuid4()) + # t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK + + yield AssistantToolCall.parse_obj(t) diff --git a/forge/forge/llm/providers/multi.py b/forge/forge/llm/providers/multi.py index 0606511043c2..e0b08352299b 100644 --- a/forge/forge/llm/providers/multi.py +++ b/forge/forge/llm/providers/multi.py @@ -7,6 +7,7 @@ from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider from .groq import GROQ_CHAT_MODELS, GroqModelName, GroqProvider +from .llamafile import LLAMAFILE_CHAT_MODELS, LlamafileModelName, LlamafileProvider from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider from .schema import ( AssistantChatMessage, @@ -24,10 +25,15 @@ _T = TypeVar("_T") -ModelName = AnthropicModelName | GroqModelName | OpenAIModelName +ModelName = AnthropicModelName | GroqModelName | LlamafileModelName | OpenAIModelName EmbeddingModelProvider = OpenAIProvider -CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **GROQ_CHAT_MODELS, **OPEN_AI_CHAT_MODELS} +CHAT_MODELS = { + **ANTHROPIC_CHAT_MODELS, + **GROQ_CHAT_MODELS, + **LLAMAFILE_CHAT_MODELS, + **OPEN_AI_CHAT_MODELS, +} class MultiProvider(BaseChatModelProvider[ModelName, ModelProviderSettings]): @@ -116,35 +122,52 @@ def get_model_provider(self, model: ModelName) -> ChatModelProvider: def get_available_providers(self) -> Iterator[ChatModelProvider]: for provider_name in ModelProviderName: + self._logger.debug(f"Checking if {provider_name} is available...") try: yield self._get_provider(provider_name) - except Exception: + self._logger.debug(f"{provider_name} is available!") + except ValueError: pass def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider: _provider = self._provider_instances.get(provider_name) if not _provider: Provider = self._get_provider_class(provider_name) + self._logger.debug( + f"{Provider.__name__} not yet in cache, trying to init..." 
+ ) + settings = Provider.default_settings.model_copy(deep=True) settings.budget = self._budget settings.configuration.extra_request_headers.update( self._settings.configuration.extra_request_headers ) if settings.credentials is None: + credentials_field = settings.model_fields["credentials"] + Credentials = get_args( # Union[Credentials, None] -> Credentials + credentials_field.annotation + )[0] + self._logger.debug(f"Loading {Credentials.__name__}...") try: - Credentials = get_args( # Union[Credentials, None] -> Credentials - settings.model_fields["credentials"].annotation - )[0] settings.credentials = Credentials.from_env() except ValidationError as e: - raise ValueError( - f"{provider_name} is unavailable: can't load credentials" - ) from e + if credentials_field.is_required(): + self._logger.debug( + f"Could not load (required) {Credentials.__name__}" + ) + raise ValueError( + f"{Provider.__name__} is unavailable: " + "can't load credentials" + ) from e + self._logger.debug( + f"Could not load {Credentials.__name__}, continuing without..." + ) self._provider_instances[provider_name] = _provider = Provider( settings=settings, logger=self._logger # type: ignore ) _provider._budget = self._budget # Object binding not preserved by Pydantic + self._logger.debug(f"Initialized {Provider.__name__}!") return _provider @classmethod @@ -155,6 +178,7 @@ def _get_provider_class( return { ModelProviderName.ANTHROPIC: AnthropicProvider, ModelProviderName.GROQ: GroqProvider, + ModelProviderName.LLAMAFILE: LlamafileProvider, ModelProviderName.OPENAI: OpenAIProvider, }[provider_name] except KeyError: @@ -164,4 +188,10 @@ def __repr__(self): return f"{self.__class__.__name__}()" -ChatModelProvider = AnthropicProvider | GroqProvider | OpenAIProvider | MultiProvider +ChatModelProvider = ( + AnthropicProvider + | GroqProvider + | LlamafileProvider + | OpenAIProvider + | MultiProvider +) diff --git a/forge/forge/llm/providers/schema.py b/forge/forge/llm/providers/schema.py index e4cc5ed5fd5c..2ca7b23e1de5 100644 --- a/forge/forge/llm/providers/schema.py +++ b/forge/forge/llm/providers/schema.py @@ -55,6 +55,7 @@ class ModelProviderName(str, enum.Enum): OPENAI = "openai" ANTHROPIC = "anthropic" GROQ = "groq" + LLAMAFILE = "llamafile" class ChatMessage(BaseModel):
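
To see how these pieces fit together, here is a minimal, untested sketch of driving the new `LlamafileProvider` directly, mirroring how `MultiProvider._get_provider` constructs providers in this diff. It assumes a llamafile server is already running at the default `http://localhost:8080/v1` (or at whatever `LLAMAFILE_API_BASE` points to):

```python
import asyncio
import logging

from forge.llm.providers.llamafile import (
    LlamafileCredentials,
    LlamafileModelName,
    LlamafileProvider,
)


async def main() -> None:
    # Build settings the same way MultiProvider does: copy the provider's
    # default settings and load credentials (LLAMAFILE_API_BASE, or the
    # default http://localhost:8080/v1) from the environment.
    settings = LlamafileProvider.default_settings.model_copy(deep=True)
    settings.credentials = LlamafileCredentials.from_env()

    provider = LlamafileProvider(
        settings=settings, logger=logging.getLogger(__name__)
    )

    # Should report mistral-7b-instruct-v0.2 if the server is serving it.
    print(await provider.get_available_models())

    # The tokenizer round-trips text through the server's /tokenize and
    # /detokenize endpoints.
    tokenizer = provider.get_tokenizer(LlamafileModelName.MISTRAL_7B_INSTRUCT)
    tokens = tokenizer.encode("Hello, llamafile!")
    print(tokens, "->", tokenizer.decode(tokens))


asyncio.run(main())
```

In normal use none of this is called by hand: setting `SMART_LLM`/`FAST_LLM` to `mistral-7b-instruct-v0.2` in `.env` is enough for `MultiProvider` to pick up `LlamafileProvider` automatically.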