fixed issue with running ControlNet
volotat committed May 5, 2023
1 parent 7e55f4d commit 9b88a8f
Showing 4 changed files with 105 additions and 31 deletions.
12 changes: 3 additions & 9 deletions readme.md
@@ -1,6 +1,5 @@
# SD-CN-Animation
-This project allows you to automate the video stylization task using StableDiffusion and ControlNet. It also allows you to generate completely new videos from text, at any resolution and length, in contrast to other current text2video methods, using any Stable Diffusion model as a backbone, including custom ones. It uses the '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and to create an inpainting mask that is used to generate the next frame. In text-to-video mode it relies on the 'FloweR' method (work in progress), which predicts optical flow from the previous frames.
+This project allows you to automate the video stylization task using StableDiffusion and ControlNet. It also allows you to generate completely new videos from text, at any resolution and length, in contrast to other current text2video methods, using any Stable Diffusion model as a backbone, including custom ones. It uses the '[RAFT](https://github.com/princeton-vl/RAFT)' optical flow estimation algorithm to keep the animation stable and to create an occlusion mask that is used to generate the next frame. In text-to-video mode it relies on the 'FloweR' method (work in progress), which predicts optical flow from the previous frames.
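
As context for the paragraph above: an occlusion mask can be derived from a forward-backward flow consistency check. The following is a minimal sketch of that idea, not the extension's actual code (the real logic lives in scripts/flow_utils.py); fwd_flow and bwd_flow are assumed to be (H, W, 2) pixel-offset arrays produced by two RAFT passes in opposite directions.

```python
import numpy as np

def occlusion_mask(fwd_flow: np.ndarray, bwd_flow: np.ndarray, thresh: float = 1.0) -> np.ndarray:
    """Mark pixels whose forward and backward flows disagree (likely occlusions)."""
    h, w = fwd_flow.shape[:2]
    # Pixel grid of (x, y) coordinates, matching the assumed (dx, dy) flow convention.
    grid = np.stack(np.meshgrid(np.arange(w), np.arange(h)), axis=-1).astype(np.float32)

    # Where each pixel lands in the next frame, clamped to the image bounds.
    target = grid + fwd_flow
    tx = np.clip(np.round(target[..., 0]).astype(int), 0, w - 1)
    ty = np.clip(np.round(target[..., 1]).astype(int), 0, h - 1)

    # For a visible pixel, following the backward flow from the landing point
    # should return (approximately) to the starting point.
    roundtrip_error = np.linalg.norm(fwd_flow + bwd_flow[ty, tx], axis=-1)
    return (roundtrip_error > thresh).astype(np.float32)  # 1.0 = treat as occluded
```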

### Video to Video Examples:
</table>
@@ -46,17 +45,12 @@ Examples presented are generated at 1024x576 resolution using the 'realisticVisi

-All the examples you can see here were originally generated at 512x512 resolution using the 'sd-v1-5-inpainting' model as a base. They were downsized and compressed for better loading speed. You can see them in their original quality in the 'examples' folder. The actual prompts used were stated in the following format: "RAW photo, {subject}, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"; only the 'subject' part is described in the table above.

## Installing the extension
-*TODO*
-Download RAFT 'raft-things.pth' from here: https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT and place it into the 'stable-diffusion-webui/models/RAFT/' folder.
-All generated video will be saved into the 'outputs/sd-cn-animation' folder.
+To install the extension, go to the 'Extensions' tab in the [Automatic1111 web-ui](https://github.com/AUTOMATIC1111/stable-diffusion-webui), then to the 'Install from URL' tab. In the 'URL for extension's git repository' field, enter the path to this repository, i.e. 'https://github.com/volotat/SD-CN-Animation'. Leave the 'Local directory name' field empty and press the 'Install' button. Download the RAFT 'raft-things.pth' model from [Google Drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT) and place it into the 'stable-diffusion-webui/models/RAFT/' folder. Restart the web-ui; a new 'SD-CN-Animation' tab should appear. All generated videos will be saved into the 'stable-diffusion-webui/outputs/sd-cn-animation' folder.

## Last version changes: v0.6
* Complete rewrite of the project to make it possible to install as an Automatic1111/Web-ui extension.
* Added a separate '-rb' flag for the background removal process at the flow computation stage in the compute_flow.py script.
* Added flow normalization before rescaling, so the flow magnitude is computed correctly at different resolutions (see the sketch after this list).
-* Less ghosting and color change in vid2vid mode
+* Less ghosting and color drift in vid2vid mode
* Added a "warped styled frame fix" in vid2vid mode that removes duplicated image artifacts in parts of the image that cannot be reconstructed from the optical flow.

1 change: 1 addition & 0 deletions scripts/base_ui.py
@@ -178,6 +178,7 @@ def on_ui_tabs():
components['cfg_scale'], # cfg_scale
dummy_component, # image_cfg_scale
components['processing_strength'], # denoising_strength
+components['fix_frame_strength'], # fix_frame_strength
components['seed'], # seed
dummy_component, # subseed
dummy_component, # subseed_strength
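
For context on this one-line change: the web-ui hands Gradio components to the handler positionally, so the component order here has to stay aligned with args_list in scripts/vid2vid.py's args_to_dict. A toy illustration of that coupling (hypothetical names, not project code):

```python
# Excerpt-style illustration: the UI side and the handler side must agree on order.
ARG_NAMES = ['denoising_strength', 'fix_frame_strength', 'seed']  # order matters

def args_to_dict_toy(*args):
    # zip pairs positionally; dropping or reordering a component on the UI side
    # would silently shift every later argument onto the wrong name.
    return dict(zip(ARG_NAMES, args))

print(args_to_dict_toy(0.75, 0.15, -1))
# -> {'denoising_strength': 0.75, 'fix_frame_strength': 0.15, 'seed': -1}
```
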
6 changes: 3 additions & 3 deletions scripts/flow_utils.py
@@ -132,10 +132,10 @@ def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_sty
diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
diff_mask_org = diff_mask_org.max(axis = -1, keepdims=True)

-diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
-diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)
+#diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
+#diff_mask_stl = diff_mask_stl.max(axis = -1, keepdims=True)

-alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 4, diff_mask_stl * 2)
+alpha_mask = np.maximum(occlusion_mask * 0.3, diff_mask_org * 3) #, diff_mask_stl * 2
alpha_mask = alpha_mask.repeat(3, axis = -1)

#alpha_mask_blured = cv2.dilate(alpha_mask, np.ones((5, 5), np.float32))
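
A note on this hunk: np.maximum takes exactly two input arrays, and a third positional argument is interpreted as the out buffer, so the removed three-argument call computed only max(occlusion_mask * 0.3, diff_mask_org * 4) and simply overwrote the temporary diff_mask_stl * 2 array with the result. A sketch of the post-commit combination, written with np.maximum.reduce so extra terms could be added back safely (illustrative, simplified from compute_diff_map):

```python
import numpy as np

def combine_alpha_mask(occlusion_mask: np.ndarray, diff_mask_org: np.ndarray) -> np.ndarray:
    """Combine per-pixel masks into the blending alpha; inputs are (H, W, 1) floats."""
    # Weights follow the values in this commit.
    alpha_mask = np.maximum.reduce([occlusion_mask * 0.3, diff_mask_org * 3])
    # Expand the single channel to three so the mask can blend RGB frames.
    return alpha_mask.repeat(3, axis=-1)
```
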
117 changes: 98 additions & 19 deletions scripts/vid2vid.py
Expand Up @@ -24,7 +24,7 @@
from modules import shared, sd_hijack, lowvram
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
-from modules.shared import opts, devices
+from modules.shared import opts, devices, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
@@ -50,6 +50,7 @@ class sdcn_anim_tmp:
curr_frame = None
prev_frame = None
prev_frame_styled = None
+prev_frame_alpha_mask = None
fps = None
total_frames = None
prepared_frames = None
@@ -92,7 +93,7 @@ def get_device():
return device

def args_to_dict(*args): # converts list of argumets into dictionary for better handling of it
-args_list = ['id_task', 'mode', 'prompt', 'negative_prompt', 'prompt_styles', 'init_video', 'sketch', 'init_img_with_mask', 'inpaint_color_sketch', 'inpaint_color_sketch_orig', 'init_img_inpaint', 'init_mask_inpaint', 'steps', 'sampler_index', 'mask_blur', 'mask_alpha', 'inpainting_fill', 'restore_faces', 'tiling', 'n_iter', 'batch_size', 'cfg_scale', 'image_cfg_scale', 'denoising_strength', 'seed', 'subseed', 'subseed_strength', 'seed_resize_from_h', 'seed_resize_from_w', 'seed_enable_extras', 'height', 'width', 'resize_mode', 'inpaint_full_res', 'inpaint_full_res_padding', 'inpainting_mask_invert', 'img2img_batch_input_dir', 'img2img_batch_output_dir', 'img2img_batch_inpaint_mask_dir', 'override_settings_texts']
+args_list = ['id_task', 'mode', 'prompt', 'negative_prompt', 'prompt_styles', 'init_video', 'sketch', 'init_img_with_mask', 'inpaint_color_sketch', 'inpaint_color_sketch_orig', 'init_img_inpaint', 'init_mask_inpaint', 'steps', 'sampler_index', 'mask_blur', 'mask_alpha', 'inpainting_fill', 'restore_faces', 'tiling', 'n_iter', 'batch_size', 'cfg_scale', 'image_cfg_scale', 'denoising_strength', 'fix_frame_strength', 'seed', 'subseed', 'subseed_strength', 'seed_resize_from_h', 'seed_resize_from_w', 'seed_enable_extras', 'height', 'width', 'resize_mode', 'inpaint_full_res', 'inpaint_full_res_padding', 'inpainting_mask_invert', 'img2img_batch_input_dir', 'img2img_batch_output_dir', 'img2img_batch_inpaint_mask_dir', 'override_settings_texts']

# set default values for params that were not specified
args_dict = {
@@ -114,6 +115,7 @@ def args_to_dict(*args): # converts list of argumets into dictionary for better
'cfg_scale': 5.5,
'image_cfg_scale': 1.5,
'denoising_strength': 0.75,
+'fix_frame_strength': 0.15,
'seed': -1,
'subseed': -1,
'subseed_strength': 0,
@@ -138,6 +140,8 @@ def args_to_dict(*args): # converts list of argumets into dictionary for better
args_dict['script_inputs'] = args[len(args_list):]
return args_dict, args

+# TODO: Refactor all the code below

def start_process(*args):
args_dict, args_list = args_to_dict(*args)

@@ -163,9 +167,9 @@ def start_process(*args):
sdcn_anim_tmp.prepared_prev_flows = np.zeros((10, args_dict['height'], args_dict['width'], 2))
sdcn_anim_tmp.prepared_frames[0] = curr_frame

-#args_dict['init_img'] = cur_frame
-args_list[5] = Image.fromarray(curr_frame)
-processed_frames, _, _, _ = modules.img2img.img2img(*args_list) #img2img(args_dict)
+args_dict['init_img'] = Image.fromarray(curr_frame)
+#args_list[5] = Image.fromarray(curr_frame)
+processed_frames, _, _, _ = img2img(args_dict)
processed_frame = np.array(processed_frames[0])
processed_frame = skimage.exposure.match_histograms(processed_frame, curr_frame, multichannel=False, channel_axis=-1)
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
@@ -177,8 +181,6 @@ def start_process(*args):
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()
yield get_cur_stat(), sdcn_anim_tmp.curr_frame, None, None, processed_frame, ''

-# TODO: SOLVE PROBLEM with wrong prev frame on the start on new processing iterations

for step in range((sdcn_anim_tmp.total_frames-1) * 2):
args_dict, args_list = args_to_dict(*args)

@@ -229,14 +231,20 @@
prev_flow = sdcn_anim_tmp.prepared_prev_flows[cn]

# process current frame
-args_list[5] = Image.fromarray(curr_frame)
-args_list[24] = -1
-processed_frames, _, _, _ = modules.img2img.img2img(*args_list)
+args_dict['init_img'] = Image.fromarray(curr_frame)
+args_dict['seed'] = -1
+#args_list[5] = Image.fromarray(curr_frame)
+#args_list[24] = -1
+processed_frames, _, _, _ = img2img(args_dict)
processed_frame = np.array(processed_frames[0])


alpha_mask, warped_styled_frame = compute_diff_map(next_flow, prev_flow, prev_frame, curr_frame, sdcn_anim_tmp.prev_frame_styled)
-alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
+if sdcn_anim_tmp.process_counter > 0:
+    alpha_mask = alpha_mask + sdcn_anim_tmp.prev_frame_alpha_mask * 0.5
+sdcn_anim_tmp.prev_frame_alpha_mask = alpha_mask
+# alpha_mask = np.clip(alpha_mask + 0.05, 0.05, 0.95)
+alpha_mask = np.clip(alpha_mask, 0, 1)

fl_w, fl_h = prev_flow.shape[:2]
prev_flow_n = prev_flow / np.array([fl_h,fl_w])
@@ -258,10 +266,13 @@ def start_process(*args):
processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
sdcn_anim_tmp.prev_frame_styled = processed_frame.copy()

-args_list[5] = Image.fromarray(processed_frame)
-args_list[23] = 0.15
-args_list[24] = 8888
-processed_frames, _, _, _ = modules.img2img.img2img(*args_list)
+args_dict['init_img'] = Image.fromarray(processed_frame)
+args_dict['denoising_strength'] = args_dict['fix_frame_strength']
+args_dict['seed'] = 8888
+#args_list[5] = Image.fromarray(processed_frame)
+#args_list[23] = 0.15
+#args_list[24] = 8888
+processed_frames, _, _, _ = img2img(args_dict)
processed_frame = np.array(processed_frames[0])

processed_frame = np.clip(processed_frame, 0, 255).astype(np.uint8)
@@ -287,7 +298,71 @@

return get_cur_stat(), curr_frame, occlusion_mask, warped_styled_frame, processed_frame, ''

-'''
+def process_img(p, input_img, output_dir, inpaint_mask_dir, args):
+    processing.fix_seed(p)
+
+    #images = shared.listfiles(input_dir)
+    images = [input_img]
+
+    is_inpaint_batch = False
+    #if inpaint_mask_dir:
+    #    inpaint_masks = shared.listfiles(inpaint_mask_dir)
+    #    is_inpaint_batch = len(inpaint_masks) > 0
+    #if is_inpaint_batch:
+    #    print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
+
+    #print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
+
+    save_normally = output_dir == ''
+
+    p.do_not_save_grid = True
+    p.do_not_save_samples = not save_normally
+
+    state.job_count = len(images) * p.n_iter
+
+    generated_images = []
+    for i, image in enumerate(images):
+        state.job = f"{i+1} out of {len(images)}"
+        if state.skipped:
+            state.skipped = False
+
+        if state.interrupted:
+            break
+
+        img = image #Image.open(image)
+        # Use the EXIF orientation of photos taken by smartphones.
+        img = ImageOps.exif_transpose(img)
+        p.init_images = [img] * p.batch_size
+
+        #if is_inpaint_batch:
+        #    # try to find corresponding mask for an image using simple filename matching
+        #    mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
+        #    # if not found use first one ("same mask for all images" use-case)
+        #    if not mask_image_path in inpaint_masks:
+        #        mask_image_path = inpaint_masks[0]
+        #    mask_image = Image.open(mask_image_path)
+        #    p.image_mask = mask_image
+
+        proc = modules.scripts.scripts_img2img.run(p, *args)
+        if proc is None:
+            proc = process_images(p)
+        generated_images.append(proc.images[0])
+
+        #for n, processed_image in enumerate(proc.images):
+        #    filename = os.path.basename(image)
+
+        #    if n > 0:
+        #        left, right = os.path.splitext(filename)
+        #        filename = f"{left}-{n}{right}"
+
+        #    if not save_normally:
+        #        os.makedirs(output_dir, exist_ok=True)
+        #        if processed_image.mode == 'RGBA':
+        #            processed_image = processed_image.convert("RGB")
+        #        processed_image.save(os.path.join(output_dir, filename))
+
+    return generated_images

# id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles: list, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args
def img2img(args_dict):
args = SimpleNamespace(**args_dict)
@@ -375,7 +450,8 @@ def img2img(args_dict):

if mask:
p.extra_generation_params["Mask blur"] = args.mask_blur

+'''
if is_batch:
...
# assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
@@ -385,7 +461,10 @@
processed = modules.scripts.scripts_img2img.run(p, *args.script_inputs)
if processed is None:
processed = process_images(p)
+'''

+generated_images = process_img(p, image, None, '', args.script_inputs)
+processed = Processed(p, [], p.seed, "")
p.close()

shared.total_tqdm.clear()
Expand All @@ -397,4 +476,4 @@ def img2img(args_dict):
if opts.do_not_show_images:
processed.images = []

-return processed.images[0] #, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)'''
+return generated_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
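
Taken together, the vid2vid.py changes restate each frame's processing as two img2img passes around a masked blend, with the alpha mask now accumulated across frames. A condensed, illustrative sketch of the loop body after this commit (simplified, with minor reordering and without the histogram matching step; img2img and compute_diff_map are the project's own functions and tmp stands in for sdcn_anim_tmp):

```python
import numpy as np
from PIL import Image
# Assumed project imports: img2img from scripts/vid2vid.py,
# compute_diff_map from scripts/flow_utils.py.

def process_one_frame(curr_frame, prev_frame, next_flow, prev_flow, tmp, args_dict):
    # Pass 1: full-strength img2img on the raw frame, fresh random seed.
    args_dict['init_img'] = Image.fromarray(curr_frame)
    args_dict['seed'] = -1
    frames, _, _, _ = img2img(args_dict)
    styled = np.array(frames[0], dtype=np.float32)

    # Masked blend with the flow-warped previous styled frame. The alpha mask
    # now carries over half of the previous frame's mask instead of using a
    # fixed 0.05..0.95 floor, which is what reduces ghosting on long sequences.
    alpha, warped_styled = compute_diff_map(
        next_flow, prev_flow, prev_frame, curr_frame, tmp.prev_frame_styled)
    if tmp.process_counter > 0:
        alpha = alpha + tmp.prev_frame_alpha_mask * 0.5
    tmp.prev_frame_alpha_mask = alpha
    alpha = np.clip(alpha, 0, 1)
    blended = styled * alpha + warped_styled.astype(np.float32) * (1 - alpha)
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    tmp.prev_frame_styled = blended.copy()

    # Pass 2: a light "fix frame" img2img pass at the new fix_frame_strength,
    # with a fixed seed, to clean blending seams without restyling the frame.
    args_dict['init_img'] = Image.fromarray(blended)
    args_dict['denoising_strength'] = args_dict['fix_frame_strength']
    args_dict['seed'] = 8888
    frames, _, _, _ = img2img(args_dict)
    return np.clip(np.array(frames[0]), 0, 255).astype(np.uint8)
```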
