class AnimateDiffKeys():
    """Collects the AnimateDiff settings and their parsed schedules.

    `enable` and `model` are copied straight from `animatediff_args`. Every
    `animatediff_<name>_schedule` string is run through
    `FrameInterpolater.parse_inbetweens` and stored on the instance as
    `<name>_schedule_series` (presumably one value per frame - confirm against
    FrameInterpolater).
    """
    def __init__(self, animatediff_args, anim_args):
        self.fi = FrameInterpolater(anim_args.max_frames)
        self.enable = animatediff_args.animatediff_enabled
        self.model = animatediff_args.animatediff_model

        # (schedule name, is it a textual schedule?) - kept in the parsing order
        schedule_table = (
            ('activation_schedule', False),
            ('motion_lora_schedule', True),
            ('video_length_schedule', False),
            ('batch_size_schedule', False),
            ('stride_schedule', False),
            ('overlap_schedule', False),
            ('latent_scale_schedule', False),
            ('latent_power_schedule', False),
            ('closed_loop_schedule', True),
        )
        for schedule_name, is_textual in schedule_table:
            raw_schedule = getattr(animatediff_args, f'animatediff_{schedule_name}')
            # only textual schedules pass the flag; the others rely on the parser's default
            parse_options = {'is_single_string': True} if is_textual else {}
            parsed = self.fi.parse_inbetweens(raw_schedule, schedule_name, **parse_options)
            setattr(self, f'{schedule_name}_series', parsed)
def get_component_names():
    """Return the names of all UI components, in the order the arguments are laid out.

    The list starts with the settings-file controls, then the animation args,
    the prompt fields, the general/output/Parseq/loop args and finally the
    AnimateDiff and ControlNet component names.
    """
    names = ['override_settings_with_file', 'custom_settings_file']
    names.extend(DeforumAnimArgs().keys())
    names.extend(['animation_prompts', 'animation_prompts_positive', 'animation_prompts_negative'])
    names.extend(DeforumArgs().keys())
    names.extend(DeforumOutputArgs().keys())
    names.extend(ParseqArgs().keys())
    names.extend(LoopArgs().keys())
    names.extend(animatediff_component_names())
    names.extend(controlnet_component_names())
    return names

def get_settings_component_names():
    """Return a fresh list with the same names as `get_component_names()`."""
    return list(get_component_names())
anim_args, parseq_args, loop_args, controlnet_args, video_args, custom_settings_file, root, run_id) + args_loaded_ok = load_args(args_dict_main, args, anim_args, parseq_args, loop_args, animatediff_args, controlnet_args, video_args, custom_settings_file, root, run_id) positive_prompts = args_dict_main['animation_prompts_positive'] negative_prompts = args_dict_main['animation_prompts_negative'] @@ -1184,4 +1186,4 @@ def process_args(args_dict_main, run_id): default_img = default_img.resize((args.W,args.H)) root.default_img = default_img - return args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, controlnet_args + return args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args diff --git a/scripts/deforum_helpers/deforum_animatediff.py b/scripts/deforum_helpers/deforum_animatediff.py new file mode 100644 index 000000000..5091586d1 --- /dev/null +++ b/scripts/deforum_helpers/deforum_animatediff.py @@ -0,0 +1,262 @@ +# Copyright (C) 2023 Deforum LLC +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, version 3 of the License. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
# Cached AnimateDiff scripts module; stays None until a lookup succeeds
cnet = None

def find_animatediff():
    """Locate the AnimateDiff extension's scripts package.

    Probes the user extensions folder first and the built-in extensions folder
    second. The first module that imports is cached in the module-level `cnet`
    variable, so later calls return it without importing again.

    Returns:
        The imported module (always truthy) when AnimateDiff is available,
        otherwise None.
    """
    global cnet
    if cnet:
        return cnet

    candidate_modules = (
        'extensions.sd-webui-animatediff.scripts',
        'extensions-builtin.sd-webui-animatediff.scripts',
    )
    for module_name in candidate_modules:
        try:
            cnet = importlib.import_module(module_name, 'animatediff')
        except Exception:
            # Best-effort probe: any import-time failure just means "not available
            # from this location". KeyboardInterrupt/SystemExit still propagate.
            continue
        # announced only once, on the call that actually discovers the extension
        print(f"\033[0;32m*Deforum AnimateDiff support: enabled*\033[0m")
        return cnet
    return None

def is_animatediff_enabled(animatediff_args):
    """Tell whether AnimateDiff is switched on in `animatediff_args`.

    A namespace without an `animatediff_enabled` attribute counts as disabled.
    Always returns a plain bool.
    """
    return bool(getattr(animatediff_args, 'animatediff_enabled', False))

def animatediff_component_names_raw():
    """Return the unprefixed names of the AnimateDiff UI components."""
    names = ['enabled', 'model']
    names += [
        'activation_schedule',
        'motion_lora_schedule',
        'video_length_schedule',
        'batch_size_schedule',
        'stride_schedule',
        'overlap_schedule',
        'latent_power_schedule',
        'latent_scale_schedule',
        'closed_loop_schedule',
    ]
    return names

def animatediff_component_names():
    """Return the component names prefixed with `animatediff_`.

    The list is empty when `find_animatediff()` reports that the extension is
    not available.
    """
    if find_animatediff():
        return [f'animatediff_{raw_name}' for raw_name in animatediff_component_names_raw()]
    return []
At the moment this is made with binary values: when the expression value is 0, it will make a pass, otherwise normal Deforum frames will be made') + with gr.Row(visible=False) as activation_row: + activation_schedule = gr.Textbox(label="AnimateDiff activation schedule", lines=1, value='0:(1), 2:((t-1) % 16)', interactive=True) + gr.Markdown('Internal AnimateDiff settings, see its script in normal tabs') + with gr.Row(visible=False) as motion_lora_row: + motion_lora_schedule = gr.Textbox(label="Motion lora schedule (not supported atm, stay tuned!)", lines=1, value='0:("")', interactive=True) + with gr.Row(visible=False) as length_row: + video_length_schedule = gr.Textbox(label="N-back video length schedule", lines=1, value='0:(16)', interactive=True) + with gr.Row(visible=False) as window_row: + batch_size_schedule = gr.Textbox(label="Batch size", lines=1, value='0:(16)', interactive=True) + with gr.Row(visible=False) as stride_row: + stride_schedule = gr.Textbox(label="Stride", lines=1, value='0:(1)', interactive=True) + with gr.Row(visible=False) as overlap_row: + overlap_schedule = gr.Textbox(label="Overlap", lines=1, value='0:(-1)', interactive=True) + with gr.Row(visible=False) as latent_power_row: + latent_power_schedule = gr.Textbox(label="Latent power schedule", lines=1, value='0:(1)', interactive=True) + with gr.Row(visible=False) as latent_scale_row: + latent_scale_schedule = gr.Textbox(label="Latent scale schedule", lines=1, value='0:(32)', interactive=True) + with gr.Row(visible=False) as rp_row: + closed_loop_schedule = gr.Textbox(label="Closed loop", lines=1, value='0:("R-P")', interactive=True) + hide_output_list = [enabled, inforow, activation_row, motion_lora_row, mod_row, length_row, window_row, stride_row, overlap_row, latent_power_row, latent_scale_row, rp_row] + for cn_output in hide_output_list: + enabled.change(fn=hide_ui_by_cn_status, inputs=enabled, outputs=cn_output) + + infotext_fields.extend([ + (model, f"AnimateDiff Model"), + ]) + + 
def find_animatediff_script(prev_always_on_scripts):
    """Pick the AnimateDiff script out of a list of always-on scripts.

    Returns the first script whose lower-cased `title()` contains
    "animatediff". Raises a plain Exception when no script matches (or when
    the matching object is falsy).
    """
    match = None
    for candidate in prev_always_on_scripts:
        if "animatediff" in candidate.title().lower():
            match = candidate
            break
    if match:
        return match
    raise Exception("AnimateDiff script not found.")

def get_animatediff_temp_dir(args):
    """Return the path of the `animatediff_temp` folder inside `args.outdir`."""
    temp_dir = os.path.join(args.outdir, 'animatediff_temp')
    return temp_dir

def need_animatediff(animatediff_args):
    """Tell whether AnimateDiff is both available and enabled for this run."""
    if find_animatediff() is None:
        return False
    return is_animatediff_enabled(animatediff_args)
directory + animatediff_temp_dir = get_animatediff_temp_dir(args) + if os.path.exists(animatediff_temp_dir): + shutil.rmtree(animatediff_temp_dir) + os.makedirs(animatediff_temp_dir) + # Copy the frames (except for the one which is being CN-made) into that dir + for offset in range(video_length - 1): + filename = f"{root.timestring}_{frame_idx - offset - 1:09}.png" + Image.open(os.path.join(args.outdir, filename)).save(os.path.join(animatediff_temp_dir, f"{offset:09}.png"), "PNG") + + animatediff_script = find_animatediff_script(prev_always_on_scripts) + # let's put it before ControlNet to cause less problems + p.is_api = True # to parse the params internally + p.scripts.alwayson_scripts = [animatediff_script] + p.scripts.alwayson_scripts + + args_dict = { + 'model': keys.model, # Motion module + 'format': ['PNG', 'Frame'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'WEBM' | 'TXT' | 'Frame' + 'enable': keys.enable, # Enable AnimateDiff + 'video_length': video_length, # Number of frames + 'fps': 8, # FPS - don't care + 'loop_number': 0, # Display loop number + 'closed_loop': keys.closed_loop_schedule_series[frame_idx], # Closed loop, 'N' | 'R-P' | 'R+P' | 'A' + 'batch_size': int(keys.batch_size_schedule_series[frame_idx]), # Context batch size + 'stride': int(keys.stride_schedule_series[frame_idx]), # Stride + 'overlap': int(keys.overlap_schedule_series[frame_idx]), # Overlap + 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' - don't care + 'interp_x': 10, # Interp X - don't care + 'video_source': '', # We don't use a video + 'video_path': animatediff_temp_dir, # Path with our selected video_length input frames + 'latent_power': keys.latent_power_schedule_series[frame_idx], # Latent power + 'latent_scale': keys.latent_scale_schedule_series[frame_idx], # Latent scale + 'last_frame': None, # Optional last frame + 'latent_power_last': 1, # Optional latent power for last frame + 'latent_scale_last': 32,# Optional latent scale for last frame + 'request_id': '' 
def reap_animatediff(images, animatediff_args, args, root, frame_idx):
    """Write the frames returned by an AnimateDiff pass over the rendered frames.

    `images` is walked from its last element backwards: the last image is
    saved as frame `frame_idx`, the one before it as `frame_idx - 1`, and so
    on. Each image overwrites `<root.timestring>_<index:09>.png` inside
    `args.outdir`. Nothing happens when AnimateDiff is unavailable or disabled.

    Args:
        images: sequence of objects exposing `save(path, format)`.
        animatediff_args: namespace holding the AnimateDiff settings.
        args: namespace providing `outdir`.
        root: namespace providing `timestring`.
        frame_idx: index of the frame the last image belongs to.

    Raises:
        FileNotFoundError: if the AnimateDiff temporary directory does not exist.
    """
    if not need_animatediff(animatediff_args):
        return

    animatediff_temp_dir = get_animatediff_temp_dir(args)
    # explicit check instead of `assert`, which is stripped when Python runs with -O
    if not os.path.exists(animatediff_temp_dir):
        raise FileNotFoundError(f"AnimateDiff temp directory not found: {animatediff_temp_dir}")

    for offset, frame in enumerate(reversed(images)):
        cur_frame_idx = frame_idx - offset
        if cur_frame_idx < 0:
            # more images than frames rendered so far: there is no earlier
            # frame to overwrite, and a negative index would create a bogus file
            break

        # overwrite the results
        filename = f"{root.timestring}_{cur_frame_idx:09}.png"
        frame.save(os.path.join(args.outdir, filename), "PNG")
f'controlnet_{i}_inputframes') for i in range(1, num_of_models + 1)] - if not any(os.path.exists(cn_inputframes) for cn_inputframes in cn_inputframes_list) and not any_loopback_mode: + if not any(os.path.exists(cn_inputframes) for cn_inputframes in cn_inputframes_list) and not any_loopback_mode and not is_animatediff_enabled(animatediff_args): print(f'\033[33mNeither the base nor the masking frames for ControlNet were found. Using the regular pipeline\033[0m') # Remove all scripts except controlnet. @@ -284,11 +288,15 @@ def read_cn_data(cn_idx): # p.scripts = copy.copy(scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img) controlnet_script = find_controlnet_script(p) + prev_always_on_scripts = p.scripts.alwayson_scripts p.scripts.alwayson_scripts = [controlnet_script] # Filling the list with None is safe because only the length will be considered, # and all cn args will be replaced. p.script_args_value = [None] * controlnet_script.args_to + # Basically, launch AD on a number of previous frames once it hits the seed time + seed_animatediff(p, prev_always_on_scripts, animatediff_args, args, anim_args, root, frame_idx) + def create_cnu_dict(cn_args, prefix, img_np, mask_np, frame_idx, CnSchKeys): keys = [ diff --git a/scripts/deforum_helpers/generate.py b/scripts/deforum_helpers/generate.py index fe48055e3..261a63987 100644 --- a/scripts/deforum_helpers/generate.py +++ b/scripts/deforum_helpers/generate.py @@ -36,6 +36,7 @@ from types import SimpleNamespace from .general_utils import debug_print +from .deforum_animatediff import reap_animatediff, is_animatediff_enabled def load_mask_latent(mask_input, shape): # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object @@ -70,14 +71,14 @@ def pairwise_repl(iterable): next(b, None) return zip(a, b) -def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None): +def generate(args, keys, anim_args, loop_args, controlnet_args, 
animatediff_args, root, parseq_adapter, frame=0, sampler_name=None): if state.interrupted: return None if args.reroll_blank_frames == 'ignore': - return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + return generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name) - image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name) if caught_vae_exception or not image.getbbox(): patience = args.reroll_patience @@ -86,7 +87,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada while caught_vae_exception or not image.getbbox(): print("Rerolling with +1 seed...") args.seed += 1 - image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name) patience -= 1 if patience == 0: print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! 
Interrupting...") @@ -100,12 +101,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada return None return image -def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None): +def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None): if cmd_opts.disable_nan_check: - image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name) else: try: - image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name) + image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name) except Exception as e: if "A tensor with all NaNs was produced in VAE." 
in repr(e): print(e) @@ -114,7 +115,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, raise e return image, False -def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None): +def generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None): # Setup the pipeline p = get_webui_sd_pipeline(args, root) p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame, anim_args.max_frames) @@ -234,8 +235,8 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars print_combined_table(args, anim_args, p_txt, keys, frame) # print dynamic table to cli - if is_controlnet_enabled(controlnet_args): - process_with_controlnet(p_txt, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=False, frame_idx=frame) + if is_controlnet_enabled(controlnet_args) or is_animatediff_enabled(animatediff_args): + process_with_controlnet(p_txt, args, anim_args, controlnet_args, animatediff_args, root, parseq_adapter, is_img2img=False, frame_idx=frame) with A1111OptionsOverrider({"control_net_detectedmap_dir" : os.path.join(args.outdir, "controlnet_detected_map")}): processed = processing.process_images(p_txt) @@ -276,8 +277,8 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars if args.motion_preview_mode: processed = mock_process_images(args, p, init_image) else: - if is_controlnet_enabled(controlnet_args): - process_with_controlnet(p, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=True, frame_idx=frame) + if is_controlnet_enabled(controlnet_args) or is_animatediff_enabled(animatediff_args): + process_with_controlnet(p, args, anim_args, controlnet_args, animatediff_args, root, parseq_adapter, is_img2img=True, frame_idx=frame) with A1111OptionsOverrider({"control_net_detectedmap_dir" : os.path.join(args.outdir, 
"controlnet_detected_map")}): processed = processing.process_images(p) @@ -287,9 +288,12 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars root.initial_info = processed.info if root.first_frame is None: - root.first_frame = processed.images[0] + root.first_frame = processed.images[-1] - results = processed.images[0] + results = processed.images[-1] # AD uses ascending order, so we need to get the last frame + + if len(processed.images) > 1: + reap_animatediff(processed.images, animatediff_args, args, root, frame) return results diff --git a/scripts/deforum_helpers/render.py b/scripts/deforum_helpers/render.py index 964498c3a..247a5ef80 100644 --- a/scripts/deforum_helpers/render.py +++ b/scripts/deforum_helpers/render.py @@ -52,11 +52,15 @@ from deforum_api import JobStatusTracker -def render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root): +def render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root): # initialise Parseq adapter + # TODO: @rewbs parseq_adapter = ParseqAdapter(parseq_args, anim_args, video_args, controlnet_args, loop_args) + if animatediff_args.animatediff_enabled: + print("*Rendering with AnimateDiff turned on. 
(Experimental!)*") + if opts.data.get("deforum_save_gen_info_as_srt", False): # create .srt file and set timeframe mechanism using FPS srt_filename = os.path.join(args.outdir, f"{root.timestring}.srt") srt_frame_duration = init_srt_file(srt_filename, video_args.fps) @@ -93,7 +97,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro print(f"Saving animation frames to:\n{args.outdir}") # save settings.txt file for the current run - save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root) + save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, animatediff_args, video_args, root) # resume from timestring if anim_args.resume_from_timestring: @@ -547,7 +551,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro args.seed = random.randint(0, 2 ** 32 - 1) print(f"Optical flow redo is diffusing and warping using {optical_flow_redo_generation} and seed {args.seed} optical flow before generation.") - disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) + disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) disposable_image = cv2.cvtColor(np.array(disposable_image), cv2.COLOR_RGB2BGR) disposable_flow = get_flow_from_images(prev_img, disposable_image, optical_flow_redo_generation, raft_model) disposable_image = cv2.cvtColor(disposable_image, cv2.COLOR_BGR2RGB) @@ -563,7 +567,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro for n in range(0, int(anim_args.diffusion_redo)): print(f"Redo generation {n + 1} of {int(anim_args.diffusion_redo)} before final generation") args.seed = random.randint(0, 2 ** 32 - 1) - disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, 
root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) + disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) disposable_image = cv2.cvtColor(np.array(disposable_image), cv2.COLOR_RGB2BGR) # color match on last one only if n == int(anim_args.diffusion_redo): @@ -574,7 +578,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro gc.collect() # generation - image = generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) + image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) if image is None: break diff --git a/scripts/deforum_helpers/render_modes.py b/scripts/deforum_helpers/render_modes.py index 46730c991..388ecfa76 100644 --- a/scripts/deforum_helpers/render_modes.py +++ b/scripts/deforum_helpers/render_modes.py @@ -30,7 +30,7 @@ from .save_images import save_image from .settings import save_settings_from_animation_run -def render_input_video(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root): +def render_input_video(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root): # create a folder for the video input frames to live in video_in_frame_path = os.path.join(args.outdir, 'inputframes') os.makedirs(video_in_frame_path, exist_ok=True) @@ -61,10 +61,10 @@ def render_input_video(args, anim_args, video_args, parseq_args, loop_args, cont args.use_mask = True args.overlay_mask = True - render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root) + render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root) # Modified a copy of the above to allow using masking video with out a init video. 
-def render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root): +def render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root): # create a folder for the video input frames to live in mask_in_frame_path = os.path.join(args.outdir, 'maskframes') os.makedirs(mask_in_frame_path, exist_ok=True) @@ -80,7 +80,7 @@ def render_animation_with_video_mask(args, anim_args, video_args, parseq_args, l #args.use_init = True print(f"Loading {anim_args.max_frames} input frames from {mask_in_frame_path} and saving video frames to {args.outdir}") - render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root) + render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root) def get_parsed_value(value, frame_idx, max_f): pattern = r'`.*?`' @@ -93,7 +93,7 @@ def get_parsed_value(value, frame_idx, max_f): parsed_value = parsed_value.replace(matched_string, str(value)) return parsed_value -def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root): +def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root): # use parseq if manifest is provided parseq_adapter = ParseqAdapter(parseq_args, anim_args, video_args, controlnet_args, loop_args) @@ -106,7 +106,7 @@ def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, co print(f"Saving interpolation animation frames to {args.outdir}") # save settings.txt file for the current run - save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root) + save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, animatediff_args, video_args, root) # Compute interpolated prompts if parseq_adapter.manages_prompts(): @@ -162,7 +162,7 @@ def 
render_interpolation(args, anim_args, video_args, parseq_args, loop_args, co args.seed = int(keys.seed_schedule_series[frame_idx]) if (args.seed_behavior == 'schedule' or parseq_adapter.manages_seed()) else args.seed opts.data["CLIP_stop_at_last_layers"] = scheduled_clipskip if scheduled_clipskip is not None else opts.data["CLIP_stop_at_last_layers"] - image = generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) + image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name) filename = f"{root.timestring}_{frame_idx:09}.png" save_image(image, 'PIL', filename, args, video_args, root) diff --git a/scripts/deforum_helpers/run_deforum.py b/scripts/deforum_helpers/run_deforum.py index ca25d152b..5065c81eb 100644 --- a/scripts/deforum_helpers/run_deforum.py +++ b/scripts/deforum_helpers/run_deforum.py @@ -71,7 +71,7 @@ def run_deforum(*args): args_dict['self'] = None args_dict['p'] = p try: - args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, controlnet_args = process_args(args_dict, i) + args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args = process_args(args_dict, i) except Exception as e: JobStatusTracker().fail_job(job_id, error_type="TERMINAL", message="Invalid arguments.") print("\n*START OF TRACEBACK*") @@ -111,13 +111,13 @@ def run_deforum(*args): JobStatusTracker().update_output_info(job_id, outdir=args.outdir, timestring=root.timestring) if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D': if anim_args.use_mask_video: - render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root) # allow mask video without an input video + render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root) 
# allow mask video without an input video else: - render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root) + render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root) elif anim_args.animation_mode == 'Video Input': - render_input_video(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root)#TODO: prettify code + render_input_video(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root)#TODO: prettify code elif anim_args.animation_mode == 'Interpolation': - render_interpolation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root) + render_interpolation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root) else: print('Other modes are not available yet!') except Exception as e: @@ -221,7 +221,7 @@ def run_deforum(*args): if shared.opts.data.get("deforum_enable_persistent_settings", False): persistent_sett_path = shared.opts.data.get("deforum_persistent_settings_path") - save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root, persistent_sett_path) + save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, animatediff_args, controlnet_args, video_args, root, persistent_sett_path) # Close the pipeline, not to interfere with ControlNet try: diff --git a/scripts/deforum_helpers/settings.py b/scripts/deforum_helpers/settings.py index bd1132d9b..214f5821f 100644 --- a/scripts/deforum_helpers/settings.py +++ b/scripts/deforum_helpers/settings.py @@ -58,7 +58,7 @@ def load_args(args_dict_main, args, anim_args, parseq_args, loop_args, controlne return True # save settings function that get calls when run_deforum is being called -def save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root, full_out_file_path = None): +def 
save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, animatediff_args, video_args, root, full_out_file_path = None): if full_out_file_path: args.__dict__["seed"] = root.raw_seed args.__dict__["batch_name"] = root.raw_batch_name @@ -69,7 +69,7 @@ def save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, co settings_filename = full_out_file_path if full_out_file_path else os.path.join(args.outdir, f"{root.timestring}_settings.txt") with open(settings_filename, "w+", encoding="utf-8") as f: s = {} - for d in (args.__dict__, anim_args.__dict__, parseq_args.__dict__, loop_args.__dict__, controlnet_args.__dict__, video_args.__dict__): + for d in (args.__dict__, anim_args.__dict__, parseq_args.__dict__, loop_args.__dict__, controlnet_args.__dict__, animatediff_args.__dict__, video_args.__dict__): s.update({k: v for k, v in d.items() if k not in exclude_keys}) s["sd_model_name"] = sh.sd_model.sd_checkpoint_info.name s["sd_model_hash"] = sh.sd_model.sd_checkpoint_info.hash diff --git a/scripts/deforum_helpers/ui_left.py b/scripts/deforum_helpers/ui_left.py index aa1886372..57b4df03a 100644 --- a/scripts/deforum_helpers/ui_left.py +++ b/scripts/deforum_helpers/ui_left.py @@ -20,6 +20,7 @@ from .gradio_funcs import change_css, handle_change_functions from .args import DeforumArgs, DeforumAnimArgs, ParseqArgs, DeforumOutputArgs, RootArgs, LoopArgs from .deforum_controlnet import setup_controlnet_ui +from .deforum_animatediff import setup_animatediff_ui from .ui_elements import get_tab_run, get_tab_keyframes, get_tab_prompts, get_tab_init, get_tab_hybrid, get_tab_output def set_arg_lists(): @@ -47,11 +48,12 @@ def setup_deforum_left_side_ui(): tab_keyframes_params = get_tab_keyframes(d, da, dloopArgs) # Keyframes tab tab_prompts_params = get_tab_prompts(da) # Prompts tab tab_init_params = get_tab_init(d, da, dp) # Init tab + animatediff_dict = setup_animatediff_ui() # AnimateDiff tab controlnet_dict = 
setup_controlnet_ui() # ControlNet tab tab_hybrid_params = get_tab_hybrid(da) # Hybrid tab tab_output_params = get_tab_output(da, dv) # Output tab # add returned gradio elements from main tabs to locals() - for key, value in {**tab_run_params, **tab_keyframes_params, **tab_prompts_params, **tab_init_params, **controlnet_dict, **tab_hybrid_params, **tab_output_params}.items(): + for key, value in {**tab_run_params, **tab_keyframes_params, **tab_prompts_params, **tab_init_params, **animatediff_dict, **controlnet_dict, **tab_hybrid_params, **tab_output_params}.items(): locals()[key] = value # Gradio's Change functions - hiding and renaming elements based on other elements