From a8f05bb76aac0a113fd476016a474b6fe334a91b Mon Sep 17 00:00:00 2001 From: Mac0q Date: Fri, 26 Apr 2024 13:31:18 +0800 Subject: [PATCH] open source support --- model_worker/README.md | 55 +++- model_worker/custom_worker.py | 67 +++++ requirements.txt | 2 +- ufo/config/config.yaml.template | 18 ++ ufo/llm/base.py | 29 +- ufo/llm/cogagent.py | 81 +++++ ufo/llm/llava.py | 519 ++++++++++++++++++++++++++++++++ ufo/llm/llm_call.py | 2 +- 8 files changed, 764 insertions(+), 9 deletions(-) create mode 100644 model_worker/custom_worker.py create mode 100644 ufo/llm/cogagent.py create mode 100644 ufo/llm/llava.py diff --git a/model_worker/README.md b/model_worker/README.md index 7fe70efc..314fa54e 100644 --- a/model_worker/README.md +++ b/model_worker/README.md @@ -3,7 +3,7 @@ The lite version of the prompt is not fully optimized. To achieve better results ### If you use QWEN as the Agent 1. QWen (Tongyi Qianwen) is a LLM developed by Alibaba. Go to [QWen](https://dashscope.aliyun.com/) and register an account and get the API key. More details can be found [here](https://help.aliyun.com/zh/dashscope/developer-reference/activate-dashscope-and-create-an-api-key?spm=a2c4g.11186623.0.0.7b5749d72j3SYU) (in Chinese). -2. Install the required packages dashscope or run the `setup.py` with `-qwen` options. +2. Uncomment the required packages in requirements.txt or install them separately. ```bash pip install dashscope ``` @@ -23,7 +23,7 @@ You can find the model name in the [QWen LLM model list](https://help.aliyun.com We provide a short example to show how to configure the ollama in the following, which might change if ollama makes updates. ```bash title="install ollama and serve LLMs in local" showLineNumbers -## Install ollama on Linux & WSL2 or run the `setup.py` with `-ollama` options +## Install ollama on Linux & WSL2. curl https://ollama.ai/install.sh | sh ## Run the serving ollama serve @@ -45,7 +45,7 @@ When serving LLMs via Ollama, it will by default start a server at `http://local "API_MODEL": "YOUR_MODEL" } ``` -NOTE: `API_BASE` is the URL started in the Ollama LLM server and `API_MODEL` is the model name of Ollama LLM, it should be same as the one you served before. In addition, due to model limitations, you can use lite version of prompt to have a taste on UFO which can be configured in `config_dev.yaml`. Attention to the top ***note***. +NOTE: `API_BASE` is the URL started in the Ollama LLM server and `API_MODEL` is the model name of Ollama LLM, it should be same as the one you served before. In addition, due to model limitations, you can use lite version of prompt to have a taste on UFO which can be configured in `config_dev.yaml`. Attention to the top ***NOTE***. #### If you use your custom model as the Agent 1. Start a server with your model, which will later be used as the API base in `config.yaml`. @@ -53,11 +53,56 @@ NOTE: `API_BASE` is the URL started in the Ollama LLM server and `API_MODEL` is 2. Add following configuration to `config.yaml`: ```json showLineNumbers { - "API_TYPE": "custom_model" , + "API_TYPE": "Custom" , "API_BASE": "YOUR_ENDPOINT", "API_KEY": "YOUR_KEY", "API_MODEL": "YOUR_MODEL" } ``` -NOTE: You should create a new Python script .py in the ufo/llm folder like the format of the .py, which needs to inherit `BaseService` as the parent class, as well as the `__init__` and `chat_completion` methods. At the same time, you need to add the dynamic import of your file in the `get_service` method of `BaseService`. \ No newline at end of file +NOTE: You should create a new Python script `custom_model.py` in the ufo/llm folder like the format of the `placeholder.py`, which needs to inherit `BaseService` as the parent class, as well as the `__init__` and `chat_completion` methods. At the same time, you need to add the dynamic import of your file in the `get_service` method of `BaseService`. + +####EXAMPLE +Also, ufo provides the usage of ***LLaVA-1.5*** and ***CogAgent*** as the example. + +1.1 Download the essential libs of your custom model. + +#### If you use LLaVA-1.5 as the Agent + +Please refer to the [LLaVA](https://github.com/haotian-liu/LLaVA) project to download and prepare the LLaVA-1.5 model, for example: + +```bash +git clone https://github.com/haotian-liu/LLaVA.git +cd LLaVA +conda create -n llava python=3.10 -y +conda activate llava +pip install --upgrade pip # enable PEP 660 support +pip install -e . +``` + +#### If you use CogAgent as the Agent + +Please refer to the [CogVLM](https://github.com/THUDM/CogVLM) project to download and prepare the CogAgent model. Download the sat version of the CogAgent weights `cogagent-chat.zip` from [here](https://huggingface.co/THUDM/CogAgent/tree/main), unzip it. + +1.2 Start your custom model. You must customize your model to support the interface of the UFO. +For simplicity, you have to configure `YOUR_ENDPOINT/chat/completions`. + +#### If you use LLaVA as the Agent +Add the `direct_generate_llava` method and a new post interface `/chat/completions` from the `custom_model_worker.py` to the into the `llava/serve/model_worker.py` And start it with the following command: +```bash +python -m llava.serve.llava_model_worker --host YOUR_HOST --port YOUR_POINT --worker YOUR_ENDPOINT --model-path liuhaotian/llava-v1.5-13b --no-register +``` + +#### If you use CogAgent as the Agent +You can modify the model generate from the `basic_demo/cli_demo.py` with a new post interface `/chat/completions` to enjoy it with UFO. + +3. Add following configuration to `config.yaml`: +```json showLineNumbers +{ + "API_TYPE": "Custom" , + "API_BASE": "YOUR_ENDPOINT", + "API_MODEL": "YOUR_MODEL" +} +``` + +***Note***: Only LLaVA and CogAgent are supported as open source models for now. If you want to use your own model, remember to modify the dynamic import of your model file in the `get_service` method of `BaseService` in `ufo/llm/base.py`. \ No newline at end of file diff --git a/model_worker/custom_worker.py b/model_worker/custom_worker.py new file mode 100644 index 00000000..504c4341 --- /dev/null +++ b/model_worker/custom_worker.py @@ -0,0 +1,67 @@ +#Method to generate response from prompt and image using the Llava model +@torch.inference_mode() +def direct_generate_llava(self, params): + tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor + + prompt = params["prompt"] + image = params.get("image", None) + if image is not None: + if DEFAULT_IMAGE_TOKEN not in prompt: + raise ValueError("Number of image does not match number of tokens in prompt") + + image = load_image_from_base64(image) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image = image.to(self.model.device, dtype=self.model.dtype) + images = image.unsqueeze(0) + + replace_token = DEFAULT_IMAGE_TOKEN + if getattr(self.model.config, 'mm_use_im_start_end', False): + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) + + num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches + else: + return {"text": "No image provided", "error_code": 0} + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + do_sample = True if temperature > 0.001 else False + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) + keywords = [stop_str] + max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) + + input_seq_len = input_ids.shape[1] + + generation_output = self.model.generate( + inputs=input_ids, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + images=images, + use_cache=True, + ) + + generation_output = generation_output[0, input_seq_len:] + decoded = tokenizer.decode(generation_output, skip_special_tokens=True) + + response = {"text": decoded} + print("response", response) + return response + + +# The API is included in llava and cogagent installations. If you customize your model, you can install fastapi via pip or uncomment the library in the requirements. +# import FastAPI +# app = FastAPI() + +#For llava +@app.post("/chat/completions") +async def generate_llava(request: Request): + params = await request.json() + response_data = worker.direct_generate_llava(params) + return response_data \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 12065ba1..4c0b4218 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ langchain==0.1.11 langchain_community==0.0.27 msal==1.25.0 openai==1.13.3 -Pillow==10.2.0 +Pillow==10.3.0 pywin32==306 pywinauto==0.6.8 PyYAML==6.0.1 diff --git a/ufo/config/config.yaml.template b/ufo/config/config.yaml.template index 8b223019..867aba04 100644 --- a/ufo/config/config.yaml.template +++ b/ufo/config/config.yaml.template @@ -15,6 +15,12 @@ HOST_AGENT: { # API_VERSION: "2024-02-15-preview", # "2024-02-15-preview" by default # API_MODEL: "YOUR_MODEL", # The only OpenAI model by now that accepts visual input # API_DEPLOYMENT_ID: "gpt-4-visual-preview", # The deployment id for the AOAI API + + ### Comment above and uncomment these according to your need if using "Qwen", "Ollama" or "Custom". + # API_TYPE: "Custom", + # API_BASE: "YOUR_ENDPOINT", + # API_KEY: "YOUR_KEY", + # API_MODEL: "YOUR_MODEL", ### For Azure_AD # AAD_TENANT_ID: "YOUR_TENANT_ID", # Set the value to your tenant id for the llm model @@ -39,6 +45,12 @@ APP_AGENT: { # API_VERSION: "2024-02-15-preview", # "2024-02-15-preview" by default # API_MODEL: "YOUR_MODEL", # The only OpenAI model by now that accepts visual input # API_DEPLOYMENT_ID: "gpt-4-visual-preview", # The deployment id for the AOAI API + + ### Comment above and uncomment these according to your need if using "Qwen", "Ollama" or "Custom". + # API_TYPE: "Custom", + # API_BASE: "YOUR_ENDPOINT", + # API_KEY: "YOUR_KEY", + # API_MODEL: "YOUR_MODEL", ### For Azure_AD # AAD_TENANT_ID: "YOUR_TENANT_ID", # Set the value to your tenant id for the llm model @@ -63,6 +75,12 @@ BACKUP_AGENT: { # API_VERSION: "2024-02-15-preview", # "2024-02-15-preview" by default # API_MODEL: "YOUR_MODEL", # The only OpenAI model by now that accepts visual input # API_DEPLOYMENT_ID: "gpt-4-visual-preview", # The deployment id for the AOAI API + + ### Comment above and uncomment these according to your need if using "Qwen", "Ollama" or "Custom". + # API_TYPE: "Custom", + # API_BASE: "YOUR_ENDPOINT", + # API_KEY: "YOUR_KEY", + # API_MODEL: "YOUR_MODEL", ### For Azure_AD # AAD_TENANT_ID: "YOUR_TENANT_ID", # Set the value to your tenant id for the llm model diff --git a/ufo/llm/base.py b/ufo/llm/base.py index dbf30a2a..a2a0ad62 100644 --- a/ufo/llm/base.py +++ b/ufo/llm/base.py @@ -14,7 +14,17 @@ def chat_completion(self, *args, **kwargs): pass @staticmethod - def get_service(name): + def get_service(name, model_name=None): + """ + Get the service based on the given name and custom model. + Args: + name (str): The name of the service. + model_name (str, optional): The model name. + Returns: + object: The service object. + Raises: + ValueError: If the given service name or model name is not supported. + """ service_map = { 'openai': 'OpenAIService', 'aoai': 'OpenAIService', @@ -22,14 +32,29 @@ def get_service(name): 'qwen': 'QwenService', 'ollama': 'OllamaService', 'placeholder': 'PlaceHolderService', + 'custom': 'CustomService', } + custom_service_map = { + 'llava': 'LlavaService', + 'cogagent': 'CogAgentService', + } service_name = service_map.get(name, None) if service_name: if name in ['aoai', 'azure_ad']: module = import_module('.openai', package='ufo.llm') + elif service_name == 'CustomService': + custom_model = 'llava' if 'llava' in model_name else model_name + custom_service_name = custom_service_map.get('llava' if 'llava' in custom_model else custom_model, None) + if custom_service_name: + module = import_module('.'+custom_model, package='ufo.llm') + service_name = custom_service_name + else: + raise ValueError(f'Custom model {custom_model} not supported') else: module = import_module('.'+name.lower(), package='ufo.llm') - return getattr(module, service_name) + return getattr(module, service_name) + else: + raise ValueError(f'Model {name} not supported') def get_cost_estimator(self, api_type, model, prices, prompt_tokens, completion_tokens) -> float: """ diff --git a/ufo/llm/cogagent.py b/ufo/llm/cogagent.py new file mode 100644 index 00000000..f017419d --- /dev/null +++ b/ufo/llm/cogagent.py @@ -0,0 +1,81 @@ +import time +from typing import Any, Optional + +import requests + +from ufo.utils import print_with_color +from .base import BaseService + + +class CogAgentService(BaseService): + def __init__(self, config, agent_type: str): + self.config_llm = config[agent_type] + self.config = config + self.max_retry = self.config["MAX_RETRY"] + self.timeout = self.config["TIMEOUT"] + self.max_tokens = 2048 #default max tokens for cogagent for now + + def chat_completion( + self, + messages, + n, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs: Any, + ): + """ + Generate chat completions based on given messages. + Args: + messages (list): A list of messages. + n (int): The number of completions to generate. + temperature (float, optional): The temperature for sampling. Defaults to None. + max_tokens (int, optional): The maximum number of tokens in the completion. Defaults to None. + top_p (float, optional): The cumulative probability for top-p sampling. Defaults to None. + **kwargs: Additional keyword arguments. + Returns: + tuple: A tuple containing the generated texts and None. + """ + + temperature = temperature if temperature is not None else self.config["TEMPERATURE"] + max_tokens = max_tokens if max_tokens is not None else self.config["MAX_TOKENS"] + top_p = top_p if top_p is not None else self.config["TOP_P"] + + texts = [] + for i in range(n): + image_base64 = None + if self.config_llm["VISUAL_MODE"]: + image_base64 = messages[1]['content'][-2]['image_url']\ + ['url'].split('base64,')[1] + prompt = messages[0]['content'] + messages[1]['content'][-1]['text'] + + payload = { + 'model': self.config_llm['API_MODEL'], + 'prompt': prompt, + 'temperature': temperature, + 'top_p': top_p, + 'max_new_tokens': self.max_tokens, + "image":image_base64 + } + + for _ in range(self.max_retry): + try: + response = requests.post(self.config_llm['API_BASE']+"/chat/completions", json=payload) + if response.status_code == 200: + response = response.json() + text = response["text"] + texts.append(text) + break + else: + raise Exception( + f"Failed to get completion with error code {response.status_code}: {response.text}", + ) + except Exception as e: + print_with_color(f"Error making API request: {e}", "red") + try: + print_with_color(response, "red") + except: + _ + time.sleep(3) + continue + return texts, None \ No newline at end of file diff --git a/ufo/llm/llava.py b/ufo/llm/llava.py new file mode 100644 index 00000000..2f3666da --- /dev/null +++ b/ufo/llm/llava.py @@ -0,0 +1,519 @@ +import time +from typing import Any, Optional +import dataclasses +from enum import auto, Enum +from typing import List +import base64 +from io import BytesIO +from PIL import Image + +import requests +from ufo.utils import print_with_color +from .base import BaseService + +DEFAULT_IMAGE_TOKEN = "" + +class LlavaService(BaseService): + def __init__(self, config, agent_type: str): + self.config_llm = config[agent_type] + self.config = config + self.max_retry = self.config["MAX_RETRY"] + self.timeout = self.config["TIMEOUT"] + self.max_tokens = 2048 #default max tokens for llava for now + + def chat_completion( + self, + messages, + n, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + top_p: Optional[float] = None, + **kwargs: Any, + ): + """ + Generates chat completions based on the given messages. + Args: + messages (list): A list of messages. + n (int): The number of completions to generate. + temperature (float, optional): The temperature value for controlling the randomness of the completions. Defaults to None. + max_tokens (int, optional): The maximum number of tokens in the completions. Defaults to None. + top_p (float, optional): The cumulative probability for selecting the next token in the completions. Defaults to None. + **kwargs: Additional keyword arguments. + Returns: + tuple: A tuple containing the generated texts and None. + Raises: + Exception: If there is an error in the API request. + """ + temperature = temperature if temperature is not None else self.config["TEMPERATURE"] + max_tokens = max_tokens if max_tokens is not None else self.config["MAX_TOKENS"] + top_p = top_p if top_p is not None else self.config["TOP_P"] + conv = conv_templates[self._conversation()].copy() + + texts = [] + for i in range(n): + if self.config_llm["VISUAL_MODE"]: + inp = DEFAULT_IMAGE_TOKEN + '\n' + messages[1]['content'][-1]['text'] + conv.append_message(conv.roles[0], inp) + image_base64 = messages[1]['content'][-2]['image_url']\ + ['url'].split('base64,')[1] + else: + conv.append_message(conv.roles[0], messages[1]['content'][-1]['text']) + prompt = conv.get_prompt() + + payload = { + 'model': self.config_llm['API_MODEL'], + 'prompt': prompt, + 'temperature': temperature, + 'top_p': top_p, + 'max_new_tokens': self.max_tokens, + "image":image_base64 + } + + for _ in range(self.max_retry): + try: + response = requests.post(self.config_llm['API_BASE']+"/chat/completions", json=payload, timeout=self.timeout) + if response.status_code == 200: + response = response.json() + text = response["text"] + texts.append(text) + break + else: + raise Exception( + f"Failed to get completion with error code {response.status_code}: {response.text}", + ) + except Exception as e: + print_with_color(f"Error making API request: {e}", "red") + try: + print_with_color(response, "red") + except: + _ + time.sleep(3) + continue + return texts, None + + + + def _conversation(self): + """ + Determines the conversation mode based on the model name. + Returns: + str: The conversation mode based on the model name. + """ + model_paths = self.config_llm["API_MODEL"].strip("/").split("/") + model_name = model_paths[-2] + "_" + model_paths[-1] if model_paths[-1].startswith('checkpoint-') else model_paths[-1] + if "llama-2" in model_name.lower(): + conv_mode = "llava_llama_2" + elif "mistral" in model_name.lower(): + conv_mode = "mistral_instruct" + elif "v1.6-34b" in model_name.lower(): + conv_mode = "chatml_direct" + elif "v1" in model_name.lower(): + conv_mode = "llava_v1" + elif "mpt" in model_name.lower(): + conv_mode = "mpt" + else: + conv_mode = "vicuna_v1" + return conv_mode + + + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + """ + Generates a prompt message based on the current state of the conversation. + Returns: + str: The generated prompt message. + """ + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + init_msg = init_msg[0].replace("", "").strip() + if 'mmtag' in self.version: + messages[0] = (init_role, init_msg) + messages.insert(0, (self.roles[0], "")) + messages.insert(1, (self.roles[1], "Received.")) + else: + messages[0] = (init_role, "\n" + init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): + """ + Process the given image based on the specified image_process_mode. + Args: + image (PIL.Image.Image): The input image to be processed. + image_process_mode (str): The mode for processing the image. Possible values are 'Pad', 'Default', 'Crop', or 'Resize'. + return_pil (bool, optional): Whether to return the processed image as a PIL Image object. Defaults to False. + image_format (str, optional): The format to save the image in. Defaults to 'PNG'. + max_len (int, optional): The maximum length of the image's longest edge. Defaults to 1344. + min_len (int, optional): The minimum length of the image's shortest edge. Defaults to 672. + Returns: + str or PIL.Image.Image: The processed image. If return_pil is True, a PIL Image object is returned. Otherwise, the processed image is returned as a base64-encoded string. + Raises: + ValueError: If an invalid image_process_mode is provided. + """ + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((336, 336)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + if max(image.size) > max_len: + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + return image + else: + buffered = BytesIO() + image.save(buffered, format=image_format) + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + return img_b64_str + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + image = self.process_image(image, image_process_mode, return_pil=return_pil) + images.append(image) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + msg, image, image_process_mode = msg + img_b64_str = self.process_image( + image, "Default", return_pil=False, + image_format='JPEG') + img_str = f'user upload image' + msg = img_str + msg.replace('', '').strip() + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + +conv_vicuna_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ("Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llama_2 = Conversation( + system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_llava_llama_2 = Conversation( + system="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_mpt = Conversation( + system="""<|im_start|>system +A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_llava_plain = Conversation( + system="", + roles=("", ""), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="\n", +) + +conv_llava_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_llava_v0_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("Human", "Assistant"), + messages=( + ), + offset=0, + sep_style=SeparatorStyle.SINGLE, + sep="###", + version="v0_mmtag", +) + +conv_llava_v1 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_llava_v1_mmtag = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "The visual content will be provided with the following format: visual content.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", + version="v1_mmtag", +) + +conv_mistral_instruct = Conversation( + system="", + roles=("USER", "ASSISTANT"), + version="llama_v2", + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA_2, + sep="", + sep2="", +) + +conv_chatml_direct = Conversation( + system="""<|im_start|>system +Answer the questions.""", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + version="mpt", + messages=(), + offset=0, + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>", +) + +conv_templates = { + "default": conv_vicuna_v0, + "v0": conv_vicuna_v0, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "llama_2": conv_llama_2, + "mistral_instruct": conv_mistral_instruct, + "chatml_direct": conv_chatml_direct, + "mistral_direct": conv_chatml_direct, + + "plain": conv_llava_plain, + "v0_plain": conv_llava_plain, + "llava_v0": conv_llava_v0, + "v0_mmtag": conv_llava_v0_mmtag, + "llava_v1": conv_llava_v1, + "v1_mmtag": conv_llava_v1_mmtag, + "llava_llama_2": conv_llava_llama_2, + + "mpt": conv_mpt, +} \ No newline at end of file diff --git a/ufo/llm/llm_call.py b/ufo/llm/llm_call.py index ec762bbf..ac8312b8 100644 --- a/ufo/llm/llm_call.py +++ b/ufo/llm/llm_call.py @@ -55,7 +55,7 @@ def get_completions(messages, agent: str='APP', use_backup_engine: bool=True, n: api_type = configs[agent_type]['API_TYPE'] try: api_type_lower = api_type.lower() - service = BaseService.get_service(api_type_lower) + service = BaseService.get_service(api_type_lower, configs[agent_type]['API_MODEL'].lower()) if service: response, cost = service(configs, agent_type=agent_type).chat_completion(messages, n) return response, cost