Add new blip2 caption processor tool
Showing 10 changed files with 286 additions and 13 deletions.
README.md
@@ -0,0 +1,33 @@
# blip2-for-sd

source: https://github.com/Talmendo/blip2-for-sd

Simple script to make BLIP2 output image descriptions in a format suitable for Stable Diffusion.

The format followed is roughly:
`[STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES], in style of [PHOTOGRAPHER]`
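
As an illustration, a caption in roughly this shape (all values hypothetical, not actual tool output) might read: `candid photo of a woman, long red hair, wearing a denim jacket, walking, full body shot, city street at dusk, natural lighting, from the side, Canon EOS R5, in style of a fashion photographer`.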

## Usage
- Install dependencies according to requirements.txt
- Run main.py:
`python main.py`

The default model will be loaded automatically from Hugging Face.
Once the model is loaded, you will be prompted for the folder to process.
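
As the argument parser at the bottom of main.py shows, the folder can also be passed directly, e.g. `python main.py /path/to/source` (the path is a placeholder), and the `--interactive` flag keeps prompting for new folders in a loop.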

<img width="854" alt="Screenshot 2023-08-04 102650" src="https://github.com/Talmendo/blip2-for-sd/assets/141401796/fa40cae5-90a4-4dd5-be1d-fc0e8312251a">

- The image or source folder should have the following structure:

![Screenshot 2023-08-04 102544](https://github.com/Talmendo/blip2-for-sd/assets/141401796/eea9c2b0-e96a-40e4-8a6d-32dd7aa3e802)

Each folder represents a base prompt to be used for every image inside.
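
That is, a layout like the following sketch is expected (folder and file names are hypothetical); main.py writes a `.txt` caption file next to each image:

```
source_folder/
├── a photo of a woman/
│   ├── 001.jpg
│   └── 002.png
└── a photo of a man/
    └── 003.webp
```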

- You can adjust BLIP2 settings in `caption_processor.py` between runs, without having to stop the script. Just update the file before entering the next source folder.

## Models
The default model is `Salesforce/blip2-opt-2.7b`, which works quite well and doesn't require much VRAM.
Also tested with `Salesforce/blip2-opt-6.7b-coco`, which seems to give better results at the cost of much more VRAM and a large download (~30GB).
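To try the larger model, change the `load_model` call at the bottom of main.py, e.g. `load_model(model_name="Salesforce/blip2-opt-6.7b-coco")`.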
caption_processor.py
@@ -0,0 +1,105 @@
import torch
import re


class CaptionProcessor:
    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device

    def gen(self, inputs, max_length=10, min_length=0, top_k=30, top_p=0.92, num_beams=4):
        return self.model.generate(
            **inputs,
            # max_new_tokens=25,  # Number of new tokens to generate, mutually exclusive with max_length
            max_length=max_length,  # Maximum length of the sequence to be generated
            num_beams=num_beams,  # Number of beams to use for beam search
            num_return_sequences=1,  # Number of captions to generate
            early_stopping=True,  # Stop beam search once enough complete candidates are found
            repetition_penalty=1.5,  # Penalise repeated words
            no_repeat_ngram_size=2,  # Size of n-grams that may not be repeated
            # do_sample=True,  # Introduce randomness to captions
            # temperature=0.9,  # Measure of randomness 0-1, lower means less random
            top_k=top_k,  # Number of highest-probability tokens to keep (applies when sampling), 0 means no filtering
            top_p=top_p,  # Probability threshold for nucleus filtering (applies when sampling), 0 means no filtering
            min_length=min_length,  # Minimum length of the sequence to be generated
        )

    def process(self, prompt, image):
        # Tokenise the prompt and preprocess the image, moving both to the target device in float16
        return self.processor(image, text=prompt, return_tensors="pt").to(self.device, torch.float16)

    def caption_from(self, generated):
        # Decode generated token ids back into text
        caption_list = self.processor.batch_decode(generated, skip_special_tokens=True)
        caption_list = [caption.strip() for caption in caption_list]
        return caption_list if len(caption_list) > 1 else caption_list[0]

    def sanitise_caption(self, caption):
        # Join spaced hyphens, e.g. "point - and" -> "point-and"
        return caption.replace(" - ", "-")

    # TODO this needs some more work
    def sanitise_prompt_shard(self, prompt):
        # Remove everything after "Answer:"
        prompt = prompt.split("Answer:")[0].strip()

        # Define patterns for multiple replacements
        replacements = [
            (r", a point and shoot(?: camera)?", ""),  # Matches ", a point and shoot" with optional " camera"
            (r"it is a ", ""),
            (r"it is ", ""),
            (r"hair hair", "hair"),
            (r"wearing nothing", "nude"),
            (r"She's ", ""),
            (r"She is ", ""),
        ]

        # Apply the replacements using regex
        for pattern, replacement in replacements:
            prompt = re.sub(pattern, replacement, prompt)

        return prompt

    def ask(self, question, image):
        # Ask BLIP2 a single VQA-style question about the image and clean up the answer
        return self.sanitise_prompt_shard(self.caption_from(self.gen(self.process(f"Question: {question} Answer:", image))))

    def caption_me(self, initial_prompt, image):
        prompt = ""

        try:
            # Target format:
            # [STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES], in style of [PHOTOGRAPHER]
            hair_color = self.ask("What is her hair color?", image)
            hair_length = self.ask("What is her hair length?", image)
            p_hair = f"{hair_color} {hair_length} hair"

            p_style = self.ask("Between the choices selfie, mirror selfie, candid, professional portrait what is the style of the photo?", image)

            p_clothing = self.ask("What is she wearing if anything?", image)

            p_action = self.ask("What is she doing? Could be something like standing, stretching, walking, squatting, etc", image)

            p_framing = self.ask("Between the choices close up, upper body shot, full body shot what is the framing of the photo?", image)

            p_setting = self.ask("Where is she? Be descriptive and detailed", image)

            p_lighting = self.ask("What is the scene lighting like? For example: soft lighting, studio lighting, natural lighting", image)

            p_angle = self.ask("What angle is the picture taken from? Be succinct, like: from the side, from below, from front", image)

            p_camera = self.ask("What kind of camera could this picture have been taken with? Be specific and guess a brand with specific camera type", image)

            # prompt = self.sanitise_caption(f"{p_style}, {initial_prompt} with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}")
            prompt = self.sanitise_caption(f"{p_style}, with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}")

            return prompt
        except Exception as e:
            print(e)

        return prompt
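
For reference, a minimal standalone sketch of how this class is driven, mirroring the model setup in main.py below (the image path and base prompt are placeholders):

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from caption_processor import CaptionProcessor

# Load the default BLIP2 model in half precision, as main.py does
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Caption a single image; the string plays the role of the folder-derived base prompt
cp = CaptionProcessor(model, processor, device)
print(cp.caption_me("photo of a woman", Image.open("example.jpg")))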
main.py
@@ -0,0 +1,89 @@
import os
import argparse

import torch
from importlib import reload
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from tqdm import tqdm

import caption_processor

model = None
processor = None
device = None


def load_model(model_name="Salesforce/blip2-opt-2.7b"):
    global model, processor, device

    print("Loading Model")
    processor = AutoProcessor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16)

    if torch.cuda.is_available():
        print("CUDA available, using GPU")
        device = "cuda"
    else:
        print("CUDA not available, using CPU")
        device = "cpu"

    print("Moving model to device")
    model.to(device)


def main(path):
    # Reload caption_processor so its settings can be tweaked between runs
    # without reloading the model, which can take very long.
    # Probably cleaner to do this with a config file and just reload that,
    # but this works for now.
    reload(caption_processor)
    prompt_file_dict = {}

    # List all sub-directories in path
    sub_dirs = [dir for dir in os.listdir(path) if os.path.isdir(os.path.join(path, dir))]

    print("Reading prompts from sub dirs and finding image files")
    for prompt in sub_dirs:
        prompt_file_dict[prompt] = [file for file in os.listdir(os.path.join(path, prompt)) if file.endswith((".jpg", ".png", ".jpeg", ".webp"))]

    for prompt, file_list in prompt_file_dict.items():
        print(f"Found {len(file_list)} files for prompt \"{prompt}\"")

    for prompt, file_list in prompt_file_dict.items():
        for file in tqdm(file_list):
            # Read the image
            image = Image.open(os.path.join(path, prompt, file))

            caption = ""
            # Generate the caption, using the folder name as the base prompt
            try:
                caption = caption_processor.CaptionProcessor(model, processor, device).caption_me(prompt, image)
            except Exception:
                print("Error creating caption for file: " + file)

            # Save the caption next to the image: same file name, .txt extension
            with open(os.path.join(path, prompt, os.path.splitext(file)[0] + ".txt"), "w", encoding="utf-8") as f:
                f.write(caption)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Caption all images in the sub-folders of the given source folder")
    parser.add_argument("path", type=str, nargs='?', default="", help="Path to the source folder")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode: keep asking for new source folders")

    args = parser.parse_args()
    interactive = args.interactive

    load_model(model_name="Salesforce/blip2-opt-2.7b")

    if interactive:
        while True:
            path = input("Enter path: ")
            main(path)
            continue_prompt = input("Continue? (y/n): ")
            if continue_prompt.lower() != 'y':
                break
    else:
        path = args.path
        main(path)
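
As a usage note, `python main.py --interactive` captions several folders in one session without reloading the model: the script asks `Enter path: ` for each source folder and `Continue? (y/n): ` after processing it, while `python main.py /path/to/source` (the path is a placeholder) processes a single folder and exits.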