diff --git a/requirements.txt b/requirements.txt
index 3b7e32ff..f47f65d7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,3 +34,7 @@ numpy==1.26.4
 # WD Tagger
 huggingface-hub==0.23.4
 onnxruntime==1.18.0
+
+# Server
+fastapi==0.111.0
+
diff --git a/taggui/auto_captioning/captioning_core.py b/taggui/auto_captioning/captioning_core.py
new file mode 100644
index 00000000..d4ddcf41
--- /dev/null
+++ b/taggui/auto_captioning/captioning_core.py
@@ -0,0 +1,399 @@
+import gc
+import re
+from contextlib import nullcontext, redirect_stdout
+from pathlib import Path
+
+import numpy as np
+import torch
+from PIL import Image as PilImage, UnidentifiedImageError
+from PIL.ImageOps import exif_transpose
+from transformers import (AutoConfig, AutoModelForCausalLM,
+                          AutoModelForVision2Seq, AutoProcessor, AutoTokenizer,
+                          BatchFeature, BitsAndBytesConfig,
+                          CodeGenTokenizerFast, LlamaTokenizer)
+
+from auto_captioning.cogvlm2 import (get_cogvlm2_error_message,
+                                     get_cogvlm2_inputs)
+from auto_captioning.cogvlm_cogagent import (get_cogvlm_cogagent_inputs,
+                                             monkey_patch_cogagent,
+                                             monkey_patch_cogvlm)
+from auto_captioning.florence_2 import get_florence_2_error_message
+from auto_captioning.models import get_model_type
+from auto_captioning.moondream import (get_moondream_error_message,
+                                       get_moondream_inputs,
+                                       monkey_patch_moondream1)
+from auto_captioning.prompts import (format_prompt, get_default_prompt,
+                                     postprocess_prompt_and_generated_text)
+from auto_captioning.wd_tagger import WdTaggerModel
+from auto_captioning.xcomposer2 import (InternLMXComposer2QuantizedForCausalLM,
+                                        get_xcomposer2_error_message,
+                                        get_xcomposer2_inputs)
+from utils.enums import CaptionDevice, CaptionModelType
+
+
+def get_tokenizer_from_processor(model_type: CaptionModelType, processor):
+    if model_type in (CaptionModelType.COGAGENT, CaptionModelType.COGVLM,
+                      CaptionModelType.COGVLM2, CaptionModelType.MOONDREAM1,
+                      CaptionModelType.MOONDREAM2,
+                      CaptionModelType.XCOMPOSER2,
+                      CaptionModelType.XCOMPOSER2_4KHD):
+        return processor
+    return processor.tokenizer
+
+
+def get_bad_words_ids(bad_words_string: str,
+                      tokenizer) -> list[list[int]] | None:
+    if not bad_words_string.strip():
+        return None
+    # Split on commas that are not escaped with a backslash.
+    words = re.split(r'(?<!\\),', bad_words_string)
+    words = [word.strip().replace(r'\,', ',') for word in words
+             if word.strip()]
+    # Also ban the words with a leading space, since tokenizers treat
+    # 'word' and ' word' as different tokens.
+    words += [' ' + word for word in words]
+    bad_words_ids = tokenizer(words, add_special_tokens=False).input_ids
+    return bad_words_ids
+
+
+def get_forced_words_ids(forced_words_string: str,
+                         tokenizer) -> list[list[list[int]]] | None:
+    if not forced_words_string.strip():
+        return None
+    # Split on commas that are not escaped with a backslash.
+    word_groups = re.split(r'(?<!\\),', forced_words_string)
+    forced_words_ids = []
+    for word_group in word_groups:
+        word_group = word_group.strip().replace(r'\,', ',')
+        if not word_group:
+            continue
+        # Let the model generate either the word or the word with a
+        # leading space.
+        words = [word_group, ' ' + word_group]
+        words_ids = tokenizer(words, add_special_tokens=False).input_ids
+        forced_words_ids.append(words_ids)
+    return forced_words_ids
+
+
+class CaptioningCore:
+    def __init__(self, caption_settings: dict, tag_separator: str,
+                 models_directory_path: Path | None = None):
+        self.caption_settings = caption_settings
+        self.tag_separator = tag_separator
+        self.models_directory_path = models_directory_path
+        self.processor = None
+        self.model = None
+        self.model_id = None
+        self.model_type = None
+        self.model_device_type = None
+        self.device = None
+        self.is_model_loaded_in_4_bit = None
+
+    def load_processor_and_model(self, device: torch.device,
+                                 model_type: CaptionModelType) -> tuple:
+        # If the processor and model were previously loaded, use them.
+        processor = self.processor
+        model = self.model
+        model_id = self.caption_settings['model']
+        # Only GPUs support 4-bit quantization.
+        load_in_4_bit = (self.caption_settings['load_in_4_bit']
+                         and device.type == 'cuda')
+        if self.models_directory_path:
+            config_path = self.models_directory_path / model_id / 'config.json'
+            tags_path = (self.models_directory_path / model_id
+                         / 'selected_tags.csv')
+            if config_path.is_file() or tags_path.is_file():
+                model_id = str(self.models_directory_path / model_id)
+        if (model and self.model_id == model_id
+                and self.model_device_type == device.type
+                and self.is_model_loaded_in_4_bit == load_in_4_bit):
+            return processor, model
+        # Load the new processor and model.
+        if model:
+            # Garbage collect the previous processor and model to free up
+            # memory.
+            self.processor = None
+            self.model = None
+            del processor
+            del model
+            gc.collect()
+        #print(f'Loading {model_id}...')
+        if model_type in (CaptionModelType.COGAGENT, CaptionModelType.COGVLM):
+            processor = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+        elif model_type == CaptionModelType.WD_TAGGER:
+            processor = None
+        else:
+            if model_type == CaptionModelType.MOONDREAM1:
+                processor_class = CodeGenTokenizerFast
+            elif model_type in (CaptionModelType.COGVLM2,
+                                CaptionModelType.MOONDREAM2,
+                                CaptionModelType.XCOMPOSER2,
+                                CaptionModelType.XCOMPOSER2_4KHD):
+                processor_class = AutoTokenizer
+            else:
+                processor_class = AutoProcessor
+            processor = processor_class.from_pretrained(model_id,
+                                                        trust_remote_code=True)
+        if model_type in (CaptionModelType.LLAVA_NEXT_34B,
+                          CaptionModelType.LLAVA_NEXT_MISTRAL,
+                          CaptionModelType.LLAVA_NEXT_VICUNA):
+            processor.tokenizer.padding_side = 'left'
+        self.processor = processor
+        if model_type == CaptionModelType.XCOMPOSER2 and load_in_4_bit:
+            with redirect_stdout(None):
+                model = InternLMXComposer2QuantizedForCausalLM.from_quantized(
+                    model_id, trust_remote_code=True, device=str(device))
+        elif model_type == CaptionModelType.WD_TAGGER:
+            model = WdTaggerModel(model_id)
+        else:
+            if model_type == CaptionModelType.MOONDREAM2:
+                revision_argument = {'revision': '2024-03-13'}
+            else:
+                revision_argument = {}
+            if load_in_4_bit:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type='nf4',
+                    bnb_4bit_compute_dtype=torch.float16,
+                    bnb_4bit_use_double_quant=True
+                )
+                dtype_argument = {}
+                if model_type == CaptionModelType.COGVLM2:
+                    config = AutoConfig.from_pretrained(model_id,
+                                                        trust_remote_code=True)
+                    config.quantization_config = quantization_config
+                    quantization_config_argument = {}
+                    config_argument = {'config': config}
+                else:
+                    quantization_config_argument = {
+                        'quantization_config': quantization_config
+                    }
+                    config_argument = {}
+            else:
+                dtype_argument = ({'torch_dtype': torch.float16}
+                                  if device.type == 'cuda' else {})
+                quantization_config_argument = {}
+                config_argument = {}
+            model_class = (AutoModelForCausalLM
+                           if model_type in (CaptionModelType.COGAGENT,
+                                             CaptionModelType.COGVLM,
+                                             CaptionModelType.COGVLM2,
+                                             CaptionModelType.FLORENCE_2,
+                                             CaptionModelType.MOONDREAM1,
+                                             CaptionModelType.MOONDREAM2,
+                                             CaptionModelType.XCOMPOSER2,
+                                             CaptionModelType.XCOMPOSER2_4KHD)
+                           else AutoModelForVision2Seq)
+            # Some models print unnecessary messages while loading, so
+            # temporarily suppress printing for them.
+            context_manager = (
+                redirect_stdout(None)
+                if model_type in (CaptionModelType.COGAGENT,
+                                  CaptionModelType.XCOMPOSER2,
+                                  CaptionModelType.XCOMPOSER2_4KHD)
+                else nullcontext())
+            with context_manager:
+                model = model_class.from_pretrained(
+                    model_id, device_map=device, trust_remote_code=True,
+                    **revision_argument, **dtype_argument,
+                    **quantization_config_argument, **config_argument)
+            if model_type == CaptionModelType.MOONDREAM1:
+                model = monkey_patch_moondream1(device, model_id)
+        if model_type != CaptionModelType.WD_TAGGER:
+            model.eval()
+        self.model = model
+        self.model_id = model_id
+        self.model_device_type = device.type
+        self.is_model_loaded_in_4_bit = load_in_4_bit
+        return processor, model
+
+    def get_prompt(self, model_type: CaptionModelType) -> str | None:
+        if model_type == CaptionModelType.WD_TAGGER:
+            return None
+        prompt = self.caption_settings['prompt']
+        if not prompt:
+            prompt = get_default_prompt(model_type)
+        prompt = format_prompt(prompt, model_type)
+        return prompt
+
+    def get_model_inputs(self, pil_image: PilImage.Image, prompt: str | None,
+                         model_type: CaptionModelType, device: torch.device,
+                         model, processor) -> BatchFeature | dict | np.ndarray:
+        mode = 'RGBA' if model_type == CaptionModelType.WD_TAGGER else 'RGB'
+        pil_image = pil_image.convert(mode)
+        if model_type == CaptionModelType.WD_TAGGER:
+            return model.get_inputs(pil_image)
+        # Prepare the input text.
+        caption_start = self.caption_settings['caption_start']
+        if model_type in (CaptionModelType.COGAGENT, CaptionModelType.COGVLM):
+            # `caption_start` is added later.
+            text = prompt
+        elif model_type in (CaptionModelType.LLAVA_LLAMA_3,
+                            CaptionModelType.LLAVA_NEXT_34B,
+                            CaptionModelType.XCOMPOSER2,
+                            CaptionModelType.XCOMPOSER2_4KHD):
+            text = prompt + caption_start
+        elif prompt and caption_start:
+            text = f'{prompt} {caption_start}'
+        else:
+            text = prompt or caption_start
+        # Convert the text and image to model inputs.
+        beam_count = self.caption_settings['generation_parameters'][
+            'num_beams']
+        dtype_argument = ({'dtype': torch.float16}
+                          if device.type == 'cuda' else {})
+        if model_type in (CaptionModelType.COGAGENT, CaptionModelType.COGVLM):
+            model_inputs = get_cogvlm_cogagent_inputs(
+                model_type, model, processor, text, pil_image, beam_count,
+                device, dtype_argument)
+        elif model_type == CaptionModelType.COGVLM2:
+            model_inputs = get_cogvlm2_inputs(model, processor, text,
+                                              pil_image, device,
+                                              dtype_argument, beam_count)
+        elif model_type in (CaptionModelType.MOONDREAM1,
+                            CaptionModelType.MOONDREAM2):
+            model_inputs = get_moondream_inputs(
+                model, processor, text, pil_image, device, dtype_argument)
+        elif model_type in (CaptionModelType.XCOMPOSER2,
+                            CaptionModelType.XCOMPOSER2_4KHD):
+            load_in_4_bit = self.caption_settings['load_in_4_bit']
+            model_inputs = get_xcomposer2_inputs(
+                model_type, model, processor, load_in_4_bit, text, pil_image,
+                device, dtype_argument)
+        else:
+            model_inputs = (processor(text=text, images=pil_image,
+                                      return_tensors='pt')
+                            .to(device, **dtype_argument))
+        return model_inputs
+
+    def get_caption_from_generated_tokens(
+            self, generated_token_ids: torch.Tensor, prompt: str, processor,
+            model_type: CaptionModelType) -> str:
+        generated_text = processor.batch_decode(
+            generated_token_ids, skip_special_tokens=True)[0]
+        prompt, generated_text = postprocess_prompt_and_generated_text(
+            model_type, processor, prompt, generated_text)
+        caption_start = self.caption_settings['caption_start']
+        if prompt.strip() and generated_text.startswith(prompt):
+            caption = generated_text[len(prompt):]
+        elif (caption_start.strip()
+                and generated_text.startswith(caption_start)):
+            caption = generated_text
+        else:
+            caption = f'{caption_start.strip()} {generated_text.strip()}'
+        caption = caption.strip()
+        if self.caption_settings['remove_tag_separators']:
+            caption = caption.replace(self.tag_separator, ' ')
+        return caption
+
+    def start_captioning(self) -> str | None:
+        model_id = self.caption_settings['model']
+        model_type = get_model_type(model_id)
+        forced_words_string = self.caption_settings['forced_words']
+        generation_parameters = self.caption_settings[
+            'generation_parameters']
+        beam_count = generation_parameters['num_beams']
+        if (forced_words_string.strip() and beam_count < 2
+                and model_type != CaptionModelType.WD_TAGGER):
+            error_message = ('`Number of beams` must be greater than 1 when '
+                             '`Include in caption` is not empty.')
+            return error_message
+        if self.caption_settings['device'] == CaptionDevice.CPU:
+            device = torch.device('cpu')
+        else:
+            gpu_index = self.caption_settings['gpu_index']
+            device = torch.device(f'cuda:{gpu_index}'
+                                  if torch.cuda.is_available() else 'cpu')
+        load_in_4_bit = self.caption_settings['load_in_4_bit']
+        caption_start = self.caption_settings['caption_start']
+        error_message = None
+        if model_type == CaptionModelType.COGVLM2:
+            error_message = get_cogvlm2_error_message(
+                model_id, self.caption_settings['device'], load_in_4_bit)
+        elif model_type == CaptionModelType.FLORENCE_2:
+            error_message = get_florence_2_error_message(
+                self.caption_settings['prompt'], caption_start)
+        elif model_type in (CaptionModelType.MOONDREAM1,
+                            CaptionModelType.MOONDREAM2):
+            beam_count = self.caption_settings['generation_parameters'][
+                'num_beams']
+            error_message = get_moondream_error_message(load_in_4_bit,
+                                                        beam_count)
+        elif model_type in (CaptionModelType.XCOMPOSER2,
+                            CaptionModelType.XCOMPOSER2_4KHD):
+            error_message = get_xcomposer2_error_message(
+                model_id, self.caption_settings['device'], load_in_4_bit)
+        if error_message:
+            return error_message
+        processor, model = self.load_processor_and_model(device, model_type)
+        # CogVLM and CogAgent have to be monkey patched every time because
+        # `caption_start` might have changed.
+        if model_type == CaptionModelType.COGVLM:
+            monkey_patch_cogvlm(caption_start)
+        elif model_type == CaptionModelType.COGAGENT:
+            monkey_patch_cogagent(model, caption_start)
+        self.processor = processor
+        self.model = model
+        self.model_type = model_type
+        self.device = device
+        return None
+
+    def run_captioning(
+            self, pil_image: PilImage.Image) -> tuple[bool, str, str]:
+        forced_words_string = self.caption_settings['forced_words'].strip()
+        generation_parameters = self.caption_settings['generation_parameters']
+        prompt = self.get_prompt(self.model_type)
+        try:
+            model_inputs = self.get_model_inputs(pil_image, prompt,
+                                                 self.model_type, self.device,
+                                                 self.model, self.processor)
+        except UnidentifiedImageError:
+            error_message = ('The image file format is not supported or the '
+                             'image is corrupted.')
+            return False, error_message, ""
+        console_output_caption = None
+        if self.model_type == CaptionModelType.WD_TAGGER:
+            wd_tagger_settings = self.caption_settings['wd_tagger_settings']
+            tags, probabilities = self.model.generate_tags(model_inputs,
+                                                           wd_tagger_settings)
+            caption = self.tag_separator.join(tags)
+            if wd_tagger_settings['show_probabilities']:
+                console_output_caption = self.tag_separator.join(
+                    f'{tag} ({probability:.2f})'
+                    for tag, probability in zip(tags, probabilities)
+                )
+        else:
+            generation_model = (
+                self.model.text_model
+                if self.model_type in (CaptionModelType.MOONDREAM1,
+                                       CaptionModelType.MOONDREAM2)
+                else self.model
+            )
+            bad_words_string = self.caption_settings['bad_words']
+            tokenizer = get_tokenizer_from_processor(self.model_type,
+                                                     self.processor)
+            bad_words_ids = get_bad_words_ids(bad_words_string, tokenizer)
+            forced_words_ids = get_forced_words_ids(forced_words_string,
+                                                    tokenizer)
+            if self.model_type == CaptionModelType.COGVLM2:
+                special_generation_parameters = {'pad_token_id': 128002}
+            elif self.model_type == CaptionModelType.LLAVA_LLAMA_3:
+                eos_token_id = (tokenizer('<|eot_id|>',
+                                          add_special_tokens=False)
+                                .input_ids)[0]
+                special_generation_parameters = {
+                    'eos_token_id': eos_token_id
+                }
+            else:
+                special_generation_parameters = {}
+            with torch.inference_mode():
+                generated_token_ids = generation_model.generate(
+                    **model_inputs, bad_words_ids=bad_words_ids,
+                    force_words_ids=forced_words_ids,
+                    **generation_parameters,
+                    **special_generation_parameters)
+            caption = self.get_caption_from_generated_tokens(
+                generated_token_ids, prompt, self.processor, self.model_type)
+        if console_output_caption is None:
+            console_output_caption = caption
+        return True, console_output_caption, caption
diff --git a/taggui/run_server.py b/taggui/run_server.py
new file mode 100644
index 00000000..2616b2e1
--- /dev/null
+++ b/taggui/run_server.py
@@ -0,0 +1,186 @@
+# configs
+appname = "taggui server"
+port = 11435  # Ollama port=11434
+
+# ----
+
+# inspired by
+# https://github.com/ollama/ollama
+# https://github.com/ollama/ollama/blob/main/docs/api.md
+
+import base64
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
+import sys
+import uvicorn
+import requests
+from io import BytesIO
+from PIL import Image as PilImage
+from PIL.ImageOps import exif_transpose
+
+from auto_captioning.captioning_core import CaptioningCore
+from auto_captioning.models import MODELS
+
+app = FastAPI()
+caption_settings = {
+    'model': "",
+    'prompt': "",
+    'caption_start': "",
+    'caption_position': "",
+    'device': "cuda:0",
+    'gpu_index': 0,
+    'load_in_4_bit': True,
+    'remove_tag_separators': True,
+    'bad_words': "",
+    'forced_words': "",
+    'generation_parameters': {
+        'min_new_tokens': 1,
+        'max_new_tokens': 100,
+        'num_beams': 1,
+        'length_penalty': 1,
+        'do_sample': False,
+        'temperature': 1,
+        'top_k': 50,
+        'top_p': 1,
+        'repetition_penalty': 1,
+        'no_repeat_ngram_size': 3,
+    },
+    'wd_tagger_settings': {
+        'show_probabilities': True,
+        'min_probability': 0.4,
+        'max_tags': 30,
+        'tags_to_exclude': "",
+    }
+}
+tag_separator = ","
+models_directory_path = None
+core = CaptioningCore(caption_settings, tag_separator, models_directory_path)
+
+class TextInput(BaseModel):
+    prompt: str
+    img_path: str
+
+@app.get("/")
+async def index():
+    return appname
+
+@app.post("/api/generate")
+async def prompt(request: Request):
+    caption = ""
+    try:
+        caption_settings = await request.json()
+        if "images" not in caption_settings:
+            raise Exception("missing 'images'")
+        core.caption_settings.update(caption_settings)
+        if (core.device is None or core.model is None
+                or core.processor is None or core.model_type is None):
+            error_message = core.start_captioning()
+            if error_message:
+                raise Exception(error_message)
+        if len(caption_settings["images"]) > 0:
+            img_bytes = base64.b64decode(caption_settings["images"][0])
+            pil_image = PilImage.open(BytesIO(img_bytes))
+            pil_image = exif_transpose(pil_image)
+            success, msg, caption = core.run_captioning(pil_image)
+            if not success: raise Exception(msg)
+    except Exception as e:
+        return { "type": "error", "msg": str(e) }
+    return { "type": "generate", "response": caption }
+
+def run_cli():
+    welcome = "Send a message: [prompt] path/to/image.png (/? for help)"
+    cli_usage = """
+Available Commands:
+  /bye        Exit
+  /?, /help   Help for a command
+"""
+    print(welcome)
+    text_input = ""
+    while True:
+        try:
+            text_input = input(">>> ")
+            if text_input == "": continue
+            if text_input == "/bye": break
+            if text_input in ["/?", "/help"]:
+                print(cli_usage)
+                continue
+
+            # The last word is the image path; everything before it is the
+            # prompt.
+            i = text_input.rfind(' ')
+            if i >= 0:
+                img_prompt = text_input[:i]
+                img_path = text_input[i + 1:]
+            else:
+                img_prompt = ""
+                img_path = text_input
+
+            try:
+                with open(img_path, 'rb') as img_file:
+                    img = PilImage.open(img_file)
+                    img.verify()
+                    img_file.seek(0)
+                    img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
+            except Exception as e:
+                print(e)
+                continue
+
+            response = requests.post(
+                f"http://127.0.0.1:{port}/api/generate",
+                json={ "prompt": img_prompt, "images": [img_base64] }).json()
+            if response["type"] == "error": raise Exception(response["msg"])
+            print(response["response"])
+        except KeyboardInterrupt:
+            if text_input == "":
+                print("Use Ctrl + d or /bye to exit.")
+        except EOFError:
+            break
+        except Exception as e:
+            print(e)
+
+if __name__ == "__main__":
+    appcommand = "python run_server.py"
+    commands = ["serve", "run"]
+    models = '\n  '.join(MODELS)
+    usage = f"""
+Usage:
+  {appcommand} [command]
+
+Available Commands:
+  serve    Start {appname}
+  run      Run a model
+"""
+    run_usage = f"""
+Run a model
+
+Usage:
+  {appcommand} run MODEL
+
+Models:
+  {models}
+"""
+    run_error = f"could not connect to {appname}, is it running?"
+
+    if len(sys.argv) <= 1:
+        print(usage)
+        quit()
+
+    command = sys.argv[1]
+    if command not in commands:
+        print(usage)
+        quit()
+
+    if command == "serve":
+        uvicorn.run(app, host="127.0.0.1", port=port)
+
+    if command == "run":
+        if len(sys.argv) <= 2:
+            print(run_usage)
+            quit()
+        model_id = sys.argv[2]
+        if model_id not in MODELS:
+            print(f"Unknown model_id='{model_id}'. Use one of:\n  {models}")
+            quit()
+
+        try:
+            print(f"Loading {model_id}...")
+
+            response = requests.post(
+                f"http://127.0.0.1:{port}/api/generate",
+                json={ "model": model_id, "images": [] }).json()
+            if response["type"] == "error": raise Exception(response["msg"])
+        except requests.exceptions.ConnectionError:
+            print(run_error)
+            quit()
+        except Exception as e:
+            print(e)
+            quit()
+
+        run_cli()
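
For reference, the /api/generate endpoint follows the shape of Ollama's generate API: the JSON body is merged into the caption settings (so keys such as "model" and "prompt" can be set per request), "images" holds base64-encoded image files, and the reply is either {"type": "generate", "response": ...} or {"type": "error", "msg": ...}. Below is a minimal client sketch, assuming the server was started with "python run_server.py serve" on the default port 11435; the image path and model id are only illustrative and must be replaced with a real file and an entry from MODELS.

import base64

import requests

# Encode a local image file that Pillow can read (illustrative path).
with open('example.jpg', 'rb') as image_file:
    image_base64 = base64.b64encode(image_file.read()).decode('utf-8')

response = requests.post(
    'http://127.0.0.1:11435/api/generate',
    json={
        'model': 'vikhyatk/moondream2',  # illustrative; use an entry from MODELS
        'prompt': 'Describe the image.',
        'images': [image_base64],
    },
).json()

if response['type'] == 'error':
    raise RuntimeError(response['msg'])
print(response['response'])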