Add new blip2 caption processor tool
bmaltais committed Aug 15, 2023
1 parent 24d0017 commit 940302c
Showing 10 changed files with 286 additions and 13 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -517,6 +517,8 @@ If you come across a `FileNotFoundError`, it is likely due to an installation issue
* 2023/08/05 (v21.8.8)
- Fix issue with aiofiles: https://github.com/bmaltais/kohya_ss/issues/1359
- Merge sd-scripts updates as of Aug 11 2023
- Add new blip2 caption processor tool
- Add dataset preparation tab to appropriate trainers
* 2023/08/05 (v21.8.7)
- Add manual captioning option. Thanks to https://github.com/channelcat for this great contribution. (https://github.com/bmaltais/kohya_ss/pull/1352)
- Added support for `v_pred_like_loss` to the advanced training tab
4 changes: 3 additions & 1 deletion dreambooth_gui.py
@@ -40,6 +40,7 @@
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample

@@ -729,7 +730,7 @@ def dreambooth_tab(
with gr.Tab('Samples', elem_id='samples_tab'):
sample = SampleImages()

with gr.Tab('Tools'):
with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provides Dreambooth tools to help set up your dataset...'
)
@@ -740,6 +741,7 @@ def dreambooth_tab(
logging_dir_input=folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)

with gr.Row():
button_run = gr.Button('Start training', variant='primary')
4 changes: 3 additions & 1 deletion library/class_dreambooth_gui.py
@@ -9,6 +9,7 @@
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from .common_gui import color_aug_changed

class Dreambooth:
@@ -49,7 +50,7 @@ def __init__(

self.sample = SampleImages()

with gr.Tab('Tools'):
with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provides Dreambooth tools to help set up your dataset...'
)
@@ -60,6 +61,7 @@ def __init__(
logging_dir_input=self.folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)

def save_to_json(self, filepath):
def serialize(obj):
2 changes: 1 addition & 1 deletion library/class_lora_tab.py
@@ -31,7 +31,7 @@ def __init__(self, folders = "", headless:bool = False):
gradio_resize_lora_tab(headless=headless)
gradio_verify_lora_tab(headless=headless)
if folders:
with gr.Tab('Deprecated'):
with gr.Tab('Dataset Preparation'):
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
28 changes: 19 additions & 9 deletions lora_gui.py
@@ -1,23 +1,14 @@
# v1: initial release
# v2: add open and save folder icons
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captionning of images to utilities

import gradio as gr
import json
import math
import os
import subprocess
import psutil
import pathlib
import argparse
from datetime import datetime
from library.common_gui import (
get_file_path,
get_any_file_path,
get_saveasfile_path,
color_aug_changed,
save_inference_file,
run_cmd_advanced_training,
run_cmd_training,
update_my_data,
@@ -44,6 +35,11 @@
from library.class_sample_images import SampleImages, run_cmd_sample
from library.class_lora_tab import LoRATools

from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab

from library.custom_logging import setup_logging

# Set up logging
@@ -1416,6 +1412,20 @@ def update_LoRA_settings(LoRA_type):
module_dropout,
],
)

with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provides Dreambooth tools to help set up your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)


with gr.Row():
button_run = gr.Button('Start training', variant='primary')
4 changes: 3 additions & 1 deletion textual_inversion_gui.py
@@ -40,6 +40,7 @@
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample

@@ -787,7 +788,7 @@ def ti_tab(
with gr.Tab('Samples', elem_id='samples_tab'):
sample = SampleImages()

with gr.Tab('Tools'):
with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provides Dreambooth tools to help set up your dataset...'
)
@@ -798,6 +799,7 @@ def ti_tab(
logging_dir_input=folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)

with gr.Row():
button_run = gr.Button('Start training', variant='primary')
33 changes: 33 additions & 0 deletions tools/blip2-for-sd/README.md
@@ -0,0 +1,33 @@
# blip2-for-sd

source: https://github.com/Talmendo/blip2-for-sd

A simple script that makes BLIP2 output image descriptions in a format suitable for Stable Diffusion.

The format followed is roughly:
`[STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES], in style of [PHOTOGRAPHER]`
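
For example, a caption following this template might read (purely illustrative, not actual output from the script):

`candid photo of a woman, long brown hair, wearing a red dress, standing, full body shot, on a beach at sunset, soft natural lighting, from the front, DSLR with a 50mm lens`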

## Usage
- Install dependencies according to `requirements.txt`

- Run main.py:
`python main.py`

The default model will be loaded automatically from Hugging Face.
After the model is loaded, you will be prompted to specify the folder to process.
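
Based on the argument parser in `main.py` (included in this commit), the script can also be pointed at a source folder directly or kept running in interactive mode; a rough sketch of both invocations, with a placeholder path:

```bash
# Caption every image in the sub-folders of a source directory (path is the positional argument)
python main.py /path/to/source_folder

# Keep the model loaded and enter source folders one at a time
python main.py --interactive
```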

<img width="854" alt="Screenshot 2023-08-04 102650" src="https://github.com/Talmendo/blip2-for-sd/assets/141401796/fa40cae5-90a4-4dd5-be1d-fc0e8312251a">


- The image or source folder should have the following structure:

![Screenshot 2023-08-04 102544](https://github.com/Talmendo/blip2-for-sd/assets/141401796/eea9c2b0-e96a-40e4-8a6d-32dd7aa3e802)


Each folder represents a base prompt to be used for every image inside.
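
As a plain-text sketch of the same layout (folder and file names are only examples; each sub-folder name is used as the base prompt for the images inside it):

```
source_folder/
├── a photo of a woman/
│   ├── 001.jpg
│   └── 002.png
└── a photo of a man/
    ├── 003.jpeg
    └── 004.webp
```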

- You can adjust BLIP2 settings in `caption_processor.py` between runs without having to stop the script: just update the file before entering the next source folder (see the excerpt below).
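
The settings most likely to be changed are the generation defaults on `CaptionProcessor.gen()`; this excerpt shows them as they appear in this commit:

```python
# caption_processor.py (excerpt) -- defaults that can be edited between runs;
# main.py calls reload(caption_processor) before each source folder is processed,
# so edits take effect without reloading the model.
def gen(self, inputs, max_length=10, min_length=0, top_k=30, top_p=0.92, num_beams=4):
    ...
```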

## Models
The default model is `Salesforce/blip2-opt-2.7b`; it works quite well and doesn't require much VRAM.
Also tested with `Salesforce/blip2-opt-6.7b-coco`, which seems to give better results at the cost of much more VRAM and a larger download (~30GB).
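
There is no command-line option for choosing the model in this commit, so switching to the larger model means editing the hard-coded call near the bottom of `main.py`, roughly:

```python
# main.py -- swap the model name passed to load_model() (example edit)
load_model(model_name="Salesforce/blip2-opt-6.7b-coco")
```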
105 changes: 105 additions & 0 deletions tools/blip2-for-sd/caption_processor.py
@@ -0,0 +1,105 @@
import torch
import re

class CaptionProcessor:
    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device

    def gen(self, inputs, max_length=10, min_length=0, top_k=30, top_p=0.92, num_beams=4):
        return self.model.generate(
            **inputs,
            # max_new_tokens=25,  # Number of tokens to generate
            max_length=max_length,  # Maximum length of the sequence to be generated, mutually exclusive with max_new_tokens
            num_beams=num_beams,  # Number of beams to use for beam search
            num_return_sequences=1,  # Number of captions to generate
            early_stopping=True,  # Stop when no new tokens are generated
            repetition_penalty=1.5,  # Penalize repeated words
            no_repeat_ngram_size=2,  # Number of words that can be repeated
            # do_sample=True,  # Introduce randomness to captions
            # temperature=0.9,  # Measure of randomness 0-1, 0 means no randomness
            top_k=top_k,  # Number of highest probability tokens to keep, 0 means no filtering
            top_p=top_p,  # Probability threshold, 0 means no filtering
            min_length=min_length,  # Minimum length of the sequence to be generated
        )

    # Preprocess the prompt/image pair into model inputs on the target device (fp16)
    def process(self, prompt, image):
        return self.processor(image, text=prompt, return_tensors="pt").to(self.device, torch.float16)

    # Decode generated token ids back into caption text
    def caption_from(self, generated):
        caption_list = self.processor.batch_decode(generated, skip_special_tokens=True)
        caption_list = [caption.strip() for caption in caption_list]
        return caption_list if len(caption_list) > 1 else caption_list[0]

    def sanitise_caption(self, caption):
        return caption.replace(" - ", "-")

    # TODO this needs some more work
    def sanitise_prompt_shard(self, prompt):
        # Remove everything after "Answer:"
        prompt = prompt.split("Answer:")[0].strip()

        # Define a pattern for multiple replacements
        replacements = [
            (r", a point and shoot(?: camera)?", ""),  # Matches ", a point and shoot" with optional " camera"
            (r"it is a ", ""),
            (r"it is ", ""),
            (r"hair hair", "hair"),
            (r"wearing nothing", "nude"),
            (r"She's ", ""),
            (r"She is ", "")
        ]

        # Apply the replacements using regex
        for pattern, replacement in replacements:
            prompt = re.sub(pattern, replacement, prompt)

        return prompt

    # Ask BLIP2 a single question about the image and return the cleaned-up answer
    def ask(self, question, image):
        return self.sanitise_prompt_shard(self.caption_from(self.gen(self.process(f"Question: {question} Answer:", image))))

    # Build an SD-style caption by asking a series of questions about the image
    def caption_me(self, initial_prompt, image):
        prompt = ""

        try:
            # [STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES],in style of [PHOTOGRAPHER]
            # print("\n")
            hair_color = self.ask("What is her hair color?", image)
            hair_length = self.ask("What is her hair length?", image)
            p_hair = f"{hair_color} {hair_length} hair"
            # print(p_hair)

            p_style = self.ask("Between the choices selfie, mirror selfie, candid, professional portrait what is the style of the photo?", image)
            # print(p_style)

            p_clothing = self.ask("What is she wearing if anything?", image)
            # print(p_clothing)

            p_action = self.ask("What is she doing? Could be something like standing, stretching, walking, squatting, etc", image)
            # print(p_action)

            p_framing = self.ask("Between the choices close up, upper body shot, full body shot what is the framing of the photo?", image)
            # print(p_framing)

            p_setting = self.ask("Where is she? Be descriptive and detailed", image)
            # print(p_setting)

            p_lighting = self.ask("What is the scene lighting like? For example: soft lighting, studio lighting, natural lighting", image)
            # print(p_lighting)

            p_angle = self.ask("What angle is the picture taken from? Be succinct, like: from the side, from below, from front", image)
            # print(p_angle)

            p_camera = self.ask("What kind of camera could this picture have been taken with? Be specific and guess a brand with specific camera type", image)
            # print(p_camera)

            # prompt = self.sanitise_caption(f"{p_style}, {initial_prompt} with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}")
            prompt = self.sanitise_caption(f"{p_style}, with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}")

            return prompt
        except Exception as e:
            print(e)

            return prompt
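
For reference, a minimal sketch of using this class directly on a single image, mirroring how `main.py` below drives it (the model name matches the script's default; the image path and base prompt are only examples):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

import caption_processor

# Load the default model the same way main.py does
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Caption a single image with a base prompt (path and prompt are illustrative)
image = Image.open("example.jpg")
cp = caption_processor.CaptionProcessor(model, processor, device)
print(cp.caption_me("photo of a woman", image))
```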
89 changes: 89 additions & 0 deletions tools/blip2-for-sd/main.py
@@ -0,0 +1,89 @@
import requests, torch, sys, os
import argparse

from importlib import reload
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from tqdm import tqdm

import caption_processor

model = None
processor = None
device = None

def load_model(model_name="Salesforce/blip2-opt-2.7b"):
    global model, processor, device

    print("Loading Model")
    processor = AutoProcessor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16)

    if torch.cuda.is_available():
        print("CUDA available, using GPU")
        device = "cuda"
    else:
        print("CUDA not available, using CPU")
        device = "cpu"

    print("Moving model to device")
    model.to(device)

def main(path):
    # reloading caption_processor to enable us to change its values in between executions
    # without having to reload the model, which can take very long
    # probably cleaner to do this with a config file and just reload that
    # but this works for now
    reload(caption_processor)
    prompt_file_dict = {}

    # list all sub dirs in path
    sub_dirs = [dir for dir in os.listdir(path) if os.path.isdir(os.path.join(path, dir))]

    print("Reading prompts from sub dirs and finding image files")
    for prompt in sub_dirs:
        prompt_file_dict[prompt] = [file for file in os.listdir(os.path.join(path, prompt)) if file.endswith((".jpg", ".png", ".jpeg", ".webp"))]

    for prompt, file_list in prompt_file_dict.items():
        print(f"Found {str(len(file_list))} files for prompt \"{prompt}\"")

    for prompt, file_list in prompt_file_dict.items():
        total = len(file_list)

        for file in tqdm(file_list):
            # read image
            image = Image.open(os.path.join(path, prompt, file))

            caption = ""
            # generate caption
            try:
                caption = caption_processor.CaptionProcessor(model, processor, device).caption_me(prompt, image)
            except Exception as e:
                print("Error creating caption for file: " + file)
                print(e)

            # save caption to file
            # file without extension
            with open(os.path.join(path, prompt, os.path.splitext(file)[0] + ".txt"), "w", encoding="utf-8") as f:
                f.write(caption)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Enter the path to the file")
    parser.add_argument("path", type=str, nargs='?', default="", help="Path to the file")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")

    args = parser.parse_args()
    interactive = args.interactive

    load_model(model_name="Salesforce/blip2-opt-2.7b")

    if interactive:
        while True:
            path = input("Enter path: ")
            main(path)
            continue_prompt = input("Continue? (y/n): ")
            if continue_prompt.lower() != 'y':
                break
    else:
        path = args.path
        search_subdirectories = False
        main(path)