diff --git a/scripts/deforum_helpers/animation_key_frames.py b/scripts/deforum_helpers/animation_key_frames.py
index ca4aa4310..af3126252 100644
--- a/scripts/deforum_helpers/animation_key_frames.py
+++ b/scripts/deforum_helpers/animation_key_frames.py
@@ -87,6 +87,21 @@ def __init__(self, anim_args, controlnet_args):
self.schedules[output_key] = self.fi.parse_inbetweens(getattr(controlnet_args, input_key), input_key)
setattr(self, output_key, self.schedules[output_key])
+class AnimateDiffKeys():
+ def __init__(self, animatediff_args, anim_args):
+ self.fi = FrameInterpolater(anim_args.max_frames)
+ self.enable = animatediff_args.animatediff_enabled
+ self.model = animatediff_args.animatediff_model
+ self.activation_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_activation_schedule, 'activation_schedule')
+ self.motion_lora_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_motion_lora_schedule, 'motion_lora_schedule', is_single_string = True)
+ self.video_length_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_video_length_schedule, 'video_length_schedule')
+ self.batch_size_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_batch_size_schedule, 'batch_size_schedule')
+ self.stride_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_stride_schedule, 'stride_schedule')
+ self.overlap_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_overlap_schedule, 'overlap_schedule')
+ self.latent_scale_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_latent_scale_schedule, 'latent_scale_schedule')
+ self.latent_power_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_latent_power_schedule, 'latent_power_schedule')
+ self.closed_loop_schedule_series = self.fi.parse_inbetweens(animatediff_args.animatediff_closed_loop_schedule, 'closed_loop_schedule', is_single_string = True)
+
class LooperAnimKeys():
def __init__(self, loop_args, anim_args, seed):
self.fi = FrameInterpolater(anim_args.max_frames, seed)
diff --git a/scripts/deforum_helpers/args.py b/scripts/deforum_helpers/args.py
index 98257597f..d354c4372 100644
--- a/scripts/deforum_helpers/args.py
+++ b/scripts/deforum_helpers/args.py
@@ -23,6 +23,7 @@
import modules.shared as sh
from modules.processing import get_fixed_seed
from .defaults import get_guided_imgs_default_json, mask_fill_choices, get_samplers_list
+from .deforum_animatediff import animatediff_component_names
from .deforum_controlnet import controlnet_component_names
from .general_utils import get_os, substitute_placeholders
@@ -1119,7 +1120,7 @@ def DeforumOutputArgs():
def get_component_names():
return ['override_settings_with_file', 'custom_settings_file', *DeforumAnimArgs().keys(), 'animation_prompts', 'animation_prompts_positive', 'animation_prompts_negative',
- *DeforumArgs().keys(), *DeforumOutputArgs().keys(), *ParseqArgs().keys(), *LoopArgs().keys(), *controlnet_component_names()]
+ *DeforumArgs().keys(), *DeforumOutputArgs().keys(), *ParseqArgs().keys(), *LoopArgs().keys(), *animatediff_component_names(), *controlnet_component_names()]
def get_settings_component_names():
return [name for name in get_component_names()]
@@ -1139,13 +1140,14 @@ def process_args(args_dict_main, run_id):
video_args = SimpleNamespace(**{name: args_dict_main[name] for name in DeforumOutputArgs()})
parseq_args = SimpleNamespace(**{name: args_dict_main[name] for name in ParseqArgs()})
loop_args = SimpleNamespace(**{name: args_dict_main[name] for name in LoopArgs()})
+ animatediff_args = SimpleNamespace(**{name: args_dict_main[name] for name in animatediff_component_names()})
controlnet_args = SimpleNamespace(**{name: args_dict_main[name] for name in controlnet_component_names()})
root.animation_prompts = json.loads(args_dict_main['animation_prompts'])
args_loaded_ok = True
if override_settings_with_file:
- args_loaded_ok = load_args(args_dict_main, args, anim_args, parseq_args, loop_args, controlnet_args, video_args, custom_settings_file, root, run_id)
+ args_loaded_ok = load_args(args_dict_main, args, anim_args, parseq_args, loop_args, animatediff_args, controlnet_args, video_args, custom_settings_file, root, run_id)
positive_prompts = args_dict_main['animation_prompts_positive']
negative_prompts = args_dict_main['animation_prompts_negative']
@@ -1184,4 +1186,4 @@ def process_args(args_dict_main, run_id):
default_img = default_img.resize((args.W,args.H))
root.default_img = default_img
- return args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, controlnet_args
+ return args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args
diff --git a/scripts/deforum_helpers/deforum_animatediff.py b/scripts/deforum_helpers/deforum_animatediff.py
new file mode 100644
index 000000000..5091586d1
--- /dev/null
+++ b/scripts/deforum_helpers/deforum_animatediff.py
@@ -0,0 +1,262 @@
+# Copyright (C) 2023 Deforum LLC
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, version 3 of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+# Contact the authors: https://deforum.github.io/
+
+# This helper script is responsible for AnimateDiff/Deforum integration
+# https://github.com/continue-revolution/sd-webui-animatediff — animatediff repo
+
+import os
+import copy
+import gradio as gr
+import scripts
+from PIL import Image
+import numpy as np
+import importlib
+import shutil
+from modules import scripts, shared
+from .deforum_controlnet_gradio import hide_ui_by_cn_status, hide_file_textboxes, ToolButton
+from .general_utils import count_files_in_folder, clean_gradio_path_strings # TODO: do it another way
+from .video_audio_utilities import vid2frames, convert_image
+from .animation_key_frames import AnimateDiffKeys
+from .load_images import load_image
+from .general_utils import debug_print
+from modules.shared import opts, cmd_opts, state, sd_model
+
+import modules.paths as ph
+
+#self.last_frame = last_frame
+#self.latent_power_last = latent_power_last
+#self.latent_scale_last = latent_scale_last
+
+cnet = None
+
+def find_animatediff():
+ global cnet
+ if cnet: return cnet
+ try:
+ cnet = importlib.import_module('extensions.sd-webui-animatediff.scripts', 'animatediff')
+ except:
+ try:
+ cnet = importlib.import_module('extensions-builtin.sd-webui-animatediff.scripts', 'animatediff')
+ except:
+ pass
+ if cnet:
+ print(f"\033[0;32m*Deforum AnimateDiff support: enabled*\033[0m")
+ return True
+ return None
+
+def is_animatediff_enabled(animatediff_args):
+ if getattr(animatediff_args, f'animatediff_enabled', False):
+ return True
+ return False
+
+def animatediff_infotext():
+ return """**Experimental!**
+Requires the AnimateDiff extension to be installed.
+"""
+
+def animatediff_component_names_raw():
+ return [
+ 'enabled', 'model', 'activation_schedule',
+ 'motion_lora_schedule',
+ 'video_length_schedule',
+ 'batch_size_schedule',
+ 'stride_schedule',
+ 'overlap_schedule',
+ 'latent_power_schedule', 'latent_scale_schedule',
+ 'closed_loop_schedule'
+ ]
+
+def animatediff_component_names():
+ if not find_animatediff():
+ return []
+
+ return [f'animatediff_{i}' for i in animatediff_component_names_raw()]
+
+def setup_animatediff_ui_raw():
+
+ cnet = find_animatediff()
+
+ model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(scripts.basedir(), "model"))
+
+ if not os.path.isdir(model_dir):
+ os.mkdir(model_dir)
+
+ cn_models = [f for f in os.listdir(model_dir) if f != ".gitkeep"]
+
+ def refresh_all_models(*inputs):
+ new_model_list = [
+ f for f in os.listdir(model_dir) if f != ".gitkeep"
+ ]
+ dd = inputs[0]
+ if dd in new_model_list:
+ selected = dd
+ elif len(new_model_list) > 0:
+ selected = new_model_list[0]
+ else:
+ selected = None
+ return gr.Dropdown.update(choices=new_model_list, value=selected)
+
+ refresh_symbol = '\U0001f504' # 🔄
+ switch_values_symbol = '\U000021C5' # ⇅
+ infotext_fields = []
+
+ # TODO: unwrap
+ def create_model_in_tab_ui(cn_id):
+ with gr.Row():
+ gr.Markdown('Note: AnimateDiff will work only if you have ControlNet installed as well')
+ enabled = gr.Checkbox(label="Enable AnimateDiff", value=False, interactive=True)
+ with gr.Row(visible=False) as mod_row:
+ model = gr.Dropdown(cn_models, label=f"Motion module", value="None", interactive=True, tooltip="Choose which motion module will be injected into the generation process.")
+ refresh_models = ToolButton(value=refresh_symbol)
+ refresh_models.click(refresh_all_models, model, model)
+ with gr.Row(visible=False) as inforow:
+ gr.Markdown('**Important!** This schedule sets up when AnimateDiff should run on the generated N previous frames. At the moment this is made with binary values: when the expression value is 0, it will make a pass, otherwise normal Deforum frames will be made')
+ with gr.Row(visible=False) as activation_row:
+ activation_schedule = gr.Textbox(label="AnimateDiff activation schedule", lines=1, value='0:(1), 2:((t-1) % 16)', interactive=True)
+ gr.Markdown('Internal AnimateDiff settings, see its script in normal tabs')
+ with gr.Row(visible=False) as motion_lora_row:
+ motion_lora_schedule = gr.Textbox(label="Motion lora schedule (not supported atm, stay tuned!)", lines=1, value='0:("")', interactive=True)
+ with gr.Row(visible=False) as length_row:
+ video_length_schedule = gr.Textbox(label="N-back video length schedule", lines=1, value='0:(16)', interactive=True)
+ with gr.Row(visible=False) as window_row:
+ batch_size_schedule = gr.Textbox(label="Batch size", lines=1, value='0:(16)', interactive=True)
+ with gr.Row(visible=False) as stride_row:
+ stride_schedule = gr.Textbox(label="Stride", lines=1, value='0:(1)', interactive=True)
+ with gr.Row(visible=False) as overlap_row:
+ overlap_schedule = gr.Textbox(label="Overlap", lines=1, value='0:(-1)', interactive=True)
+ with gr.Row(visible=False) as latent_power_row:
+ latent_power_schedule = gr.Textbox(label="Latent power schedule", lines=1, value='0:(1)', interactive=True)
+ with gr.Row(visible=False) as latent_scale_row:
+ latent_scale_schedule = gr.Textbox(label="Latent scale schedule", lines=1, value='0:(32)', interactive=True)
+ with gr.Row(visible=False) as rp_row:
+ closed_loop_schedule = gr.Textbox(label="Closed loop", lines=1, value='0:("R-P")', interactive=True)
+ hide_output_list = [enabled, inforow, activation_row, motion_lora_row, mod_row, length_row, window_row, stride_row, overlap_row, latent_power_row, latent_scale_row, rp_row]
+ for cn_output in hide_output_list:
+ enabled.change(fn=hide_ui_by_cn_status, inputs=enabled, outputs=cn_output)
+
+ infotext_fields.extend([
+ (model, f"AnimateDiff Model"),
+ ])
+
+ return {key: value for key, value in locals().items() if key in
+ animatediff_component_names_raw()
+ }
+
+ with gr.TabItem('AnimateDiff'):
+ gr.HTML(animatediff_infotext())
+ model_params = create_model_in_tab_ui(0)
+
+ for key, value in model_params.items():
+ locals()[f"animatediff_{key}"] = value
+
+ return locals()
+
+def setup_animatediff_ui():
+ if not find_animatediff():
+ gr.HTML("""AnimateDiff not found. Please install it :)""", elem_id='animatediff_not_found_html_msg')
+ return {}
+
+ try:
+ return setup_animatediff_ui_raw()
+ except Exception as e:
+ print(f"'AnimateDiff UI setup failed with error: '{e}'!")
+ gr.HTML(f"""
+ Failed to setup AnimateDiff UI, check the reason in your commandline log. Please, downgrade your AnimateDiff extension to b192a2551a5ed66d4a3ce58d5d19a8872abc87ca and report the problem here (Deforum) or here (AnimateDiff).
+ """, elem_id='animatediff_not_found_html_msg')
+ return {}
+
+def find_animatediff_script(prev_always_on_scripts):
+ animatediff_script = next((script for script in prev_always_on_scripts if "animatediff" in script.title().lower()), None)
+ if not animatediff_script:
+ raise Exception("AnimateDiff script not found.")
+ return animatediff_script
+
+def get_animatediff_temp_dir(args):
+ return os.path.join(args.outdir, 'animatediff_temp')
+
+def need_animatediff(animatediff_args):
+ return find_animatediff() is not None and is_animatediff_enabled(animatediff_args)
+
+def seed_animatediff(p, prev_always_on_scripts, animatediff_args, args, anim_args, root, frame_idx):
+ if not need_animatediff(animatediff_args):
+ return
+
+ keys = AnimateDiffKeys(animatediff_args, anim_args) # if not parseq_adapter.use_parseq else parseq_adapter.cn_keys
+
+ # Will do the back-render only on target frames
+ if int(keys.activation_schedule_series[frame_idx]) != 0:
+ return
+
+ video_length = int(keys.video_length_schedule_series[frame_idx])
+ assert video_length > 1
+
+ # Managing the frames to be fed into AD:
+ # Create a temporal directory
+ animatediff_temp_dir = get_animatediff_temp_dir(args)
+ if os.path.exists(animatediff_temp_dir):
+ shutil.rmtree(animatediff_temp_dir)
+ os.makedirs(animatediff_temp_dir)
+ # Copy the frames (except for the one which is being CN-made) into that dir
+ for offset in range(video_length - 1):
+ filename = f"{root.timestring}_{frame_idx - offset - 1:09}.png"
+ Image.open(os.path.join(args.outdir, filename)).save(os.path.join(animatediff_temp_dir, f"{offset:09}.png"), "PNG")
+
+ animatediff_script = find_animatediff_script(prev_always_on_scripts)
+ # let's put it before ControlNet to cause less problems
+ p.is_api = True # to parse the params internally
+ p.scripts.alwayson_scripts = [animatediff_script] + p.scripts.alwayson_scripts
+
+ args_dict = {
+ 'model': keys.model, # Motion module
+ 'format': ['PNG', 'Frame'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'WEBM' | 'TXT' | 'Frame'
+ 'enable': keys.enable, # Enable AnimateDiff
+ 'video_length': video_length, # Number of frames
+ 'fps': 8, # FPS - don't care
+ 'loop_number': 0, # Display loop number
+ 'closed_loop': keys.closed_loop_schedule_series[frame_idx], # Closed loop, 'N' | 'R-P' | 'R+P' | 'A'
+ 'batch_size': int(keys.batch_size_schedule_series[frame_idx]), # Context batch size
+ 'stride': int(keys.stride_schedule_series[frame_idx]), # Stride
+ 'overlap': int(keys.overlap_schedule_series[frame_idx]), # Overlap
+ 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM' - don't care
+ 'interp_x': 10, # Interp X - don't care
+ 'video_source': '', # We don't use a video
+ 'video_path': animatediff_temp_dir, # Path with our selected video_length input frames
+ 'latent_power': keys.latent_power_schedule_series[frame_idx], # Latent power
+ 'latent_scale': keys.latent_scale_schedule_series[frame_idx], # Latent scale
+ 'last_frame': None, # Optional last frame
+ 'latent_power_last': 1, # Optional latent power for last frame
+ 'latent_scale_last': 32,# Optional latent scale for last frame
+ 'request_id': '' # Optional request id. If provided, outputs will have request id as filename suffix
+ }
+
+ args = [None] * 10 + [args_dict] # HACK hardcoded args offset
+
+ p.script_args_value = args + p.script_args_value
+
+def reap_animatediff(images, animatediff_args, args, root, frame_idx):
+ if not need_animatediff(animatediff_args):
+ return
+
+ animatediff_temp_dir = get_animatediff_temp_dir(args)
+ assert os.path.exists(animatediff_temp_dir)
+
+ for offset in range(len(images)):
+ frame = images[-offset-1]
+ cur_frame_idx = frame_idx - offset
+
+ # overwrite the results
+ filename = f"{root.timestring}_{cur_frame_idx:09}.png"
+ frame.save(os.path.join(args.outdir, filename), "PNG")
diff --git a/scripts/deforum_helpers/deforum_controlnet.py b/scripts/deforum_helpers/deforum_controlnet.py
index cef20cbc8..729b85036 100644
--- a/scripts/deforum_helpers/deforum_controlnet.py
+++ b/scripts/deforum_helpers/deforum_controlnet.py
@@ -38,6 +38,9 @@
max_models = shared.opts.data.get("control_net_unit_count", shared.opts.data.get("control_net_max_models_num", 5))
num_of_models = 5 if max_models <= 5 else max_models
+# AnimateDiff support (it requires ControlNet anyway)
+from .deforum_animatediff import seed_animatediff, is_animatediff_enabled
+
def find_controlnet():
global cnet
if cnet: return cnet
@@ -217,7 +220,8 @@ def controlnet_component_names():
'processor_res', 'threshold_a', 'threshold_b', 'resize_mode', 'control_mode', 'loopback_mode'
]]
-def process_with_controlnet(p, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=True, frame_idx=0):
+def process_with_controlnet(p, args, anim_args, controlnet_args, animatediff_args, root, parseq_adapter, is_img2img=True, frame_idx=0):
+ p.do_not_save_grid = True
CnSchKeys = ControlNetKeys(anim_args, controlnet_args) if not parseq_adapter.use_parseq else parseq_adapter.cn_keys
def read_cn_data(cn_idx):
@@ -264,7 +268,7 @@ def read_cn_data(cn_idx):
cn_inputframes_list = [os.path.join(args.outdir, f'controlnet_{i}_inputframes') for i in range(1, num_of_models + 1)]
- if not any(os.path.exists(cn_inputframes) for cn_inputframes in cn_inputframes_list) and not any_loopback_mode:
+ if not any(os.path.exists(cn_inputframes) for cn_inputframes in cn_inputframes_list) and not any_loopback_mode and not is_animatediff_enabled(animatediff_args):
print(f'\033[33mNeither the base nor the masking frames for ControlNet were found. Using the regular pipeline\033[0m')
# Remove all scripts except controlnet.
@@ -284,11 +288,15 @@ def read_cn_data(cn_idx):
#
p.scripts = copy.copy(scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img)
controlnet_script = find_controlnet_script(p)
+ prev_always_on_scripts = p.scripts.alwayson_scripts
p.scripts.alwayson_scripts = [controlnet_script]
# Filling the list with None is safe because only the length will be considered,
# and all cn args will be replaced.
p.script_args_value = [None] * controlnet_script.args_to
+ # Basically, launch AD on a number of previous frames once it hits the seed time
+ seed_animatediff(p, prev_always_on_scripts, animatediff_args, args, anim_args, root, frame_idx)
+
def create_cnu_dict(cn_args, prefix, img_np, mask_np, frame_idx, CnSchKeys):
keys = [
diff --git a/scripts/deforum_helpers/generate.py b/scripts/deforum_helpers/generate.py
index fe48055e3..261a63987 100644
--- a/scripts/deforum_helpers/generate.py
+++ b/scripts/deforum_helpers/generate.py
@@ -36,6 +36,7 @@
from types import SimpleNamespace
from .general_utils import debug_print
+from .deforum_animatediff import reap_animatediff, is_animatediff_enabled
def load_mask_latent(mask_input, shape):
# mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object
@@ -70,14 +71,14 @@ def pairwise_repl(iterable):
next(b, None)
return zip(a, b)
-def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
+def generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None):
if state.interrupted:
return None
if args.reroll_blank_frames == 'ignore':
- return generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
+ return generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
- image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
+ image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
if caught_vae_exception or not image.getbbox():
patience = args.reroll_patience
@@ -86,7 +87,7 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
while caught_vae_exception or not image.getbbox():
print("Rerolling with +1 seed...")
args.seed += 1
- image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
+ image, caught_vae_exception = generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
patience -= 1
if patience == 0:
print("Rerolling with +1 seed failed for 10 iterations! Try setting webui's precision to 'full' and if it fails, please report this to the devs! Interrupting...")
@@ -100,12 +101,12 @@ def generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_ada
return None
return image
-def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
+def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None):
if cmd_opts.disable_nan_check:
- image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
+ image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
else:
try:
- image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame, sampler_name)
+ image = generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame, sampler_name)
except Exception as e:
if "A tensor with all NaNs was produced in VAE." in repr(e):
print(e)
@@ -114,7 +115,7 @@ def generate_with_nans_check(args, keys, anim_args, loop_args, controlnet_args,
raise e
return image, False
-def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame=0, sampler_name=None):
+def generate_inner(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame=0, sampler_name=None):
# Setup the pipeline
p = get_webui_sd_pipeline(args, root)
p.prompt, p.negative_prompt = split_weighted_subprompts(args.prompt, frame, anim_args.max_frames)
@@ -234,8 +235,8 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
print_combined_table(args, anim_args, p_txt, keys, frame) # print dynamic table to cli
- if is_controlnet_enabled(controlnet_args):
- process_with_controlnet(p_txt, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=False, frame_idx=frame)
+ if is_controlnet_enabled(controlnet_args) or is_animatediff_enabled(animatediff_args):
+ process_with_controlnet(p_txt, args, anim_args, controlnet_args, animatediff_args, root, parseq_adapter, is_img2img=False, frame_idx=frame)
with A1111OptionsOverrider({"control_net_detectedmap_dir" : os.path.join(args.outdir, "controlnet_detected_map")}):
processed = processing.process_images(p_txt)
@@ -276,8 +277,8 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
if args.motion_preview_mode:
processed = mock_process_images(args, p, init_image)
else:
- if is_controlnet_enabled(controlnet_args):
- process_with_controlnet(p, args, anim_args, controlnet_args, root, parseq_adapter, is_img2img=True, frame_idx=frame)
+ if is_controlnet_enabled(controlnet_args) or is_animatediff_enabled(animatediff_args):
+ process_with_controlnet(p, args, anim_args, controlnet_args, animatediff_args, root, parseq_adapter, is_img2img=True, frame_idx=frame)
with A1111OptionsOverrider({"control_net_detectedmap_dir" : os.path.join(args.outdir, "controlnet_detected_map")}):
processed = processing.process_images(p)
@@ -287,9 +288,12 @@ def generate_inner(args, keys, anim_args, loop_args, controlnet_args, root, pars
root.initial_info = processed.info
if root.first_frame is None:
- root.first_frame = processed.images[0]
+ root.first_frame = processed.images[-1]
- results = processed.images[0]
+ results = processed.images[-1] # AD uses ascending order, so we need to get the last frame
+
+ if len(processed.images) > 1:
+ reap_animatediff(processed.images, animatediff_args, args, root, frame)
return results
diff --git a/scripts/deforum_helpers/render.py b/scripts/deforum_helpers/render.py
index 964498c3a..247a5ef80 100644
--- a/scripts/deforum_helpers/render.py
+++ b/scripts/deforum_helpers/render.py
@@ -52,11 +52,15 @@
from deforum_api import JobStatusTracker
-def render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root):
+def render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root):
# initialise Parseq adapter
+ # TODO: @rewbs
parseq_adapter = ParseqAdapter(parseq_args, anim_args, video_args, controlnet_args, loop_args)
+ if animatediff_args.animatediff_enabled:
+ print("*Rendering with AnimateDiff turned on. (Experimental!)*")
+
if opts.data.get("deforum_save_gen_info_as_srt", False): # create .srt file and set timeframe mechanism using FPS
srt_filename = os.path.join(args.outdir, f"{root.timestring}.srt")
srt_frame_duration = init_srt_file(srt_filename, video_args.fps)
@@ -93,7 +97,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro
print(f"Saving animation frames to:\n{args.outdir}")
# save settings.txt file for the current run
- save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root)
+ save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, animatediff_args, video_args, root)
# resume from timestring
if anim_args.resume_from_timestring:
@@ -547,7 +551,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro
args.seed = random.randint(0, 2 ** 32 - 1)
print(f"Optical flow redo is diffusing and warping using {optical_flow_redo_generation} and seed {args.seed} optical flow before generation.")
- disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
+ disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
disposable_image = cv2.cvtColor(np.array(disposable_image), cv2.COLOR_RGB2BGR)
disposable_flow = get_flow_from_images(prev_img, disposable_image, optical_flow_redo_generation, raft_model)
disposable_image = cv2.cvtColor(disposable_image, cv2.COLOR_BGR2RGB)
@@ -563,7 +567,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro
for n in range(0, int(anim_args.diffusion_redo)):
print(f"Redo generation {n + 1} of {int(anim_args.diffusion_redo)} before final generation")
args.seed = random.randint(0, 2 ** 32 - 1)
- disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
+ disposable_image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
disposable_image = cv2.cvtColor(np.array(disposable_image), cv2.COLOR_RGB2BGR)
# color match on last one only
if n == int(anim_args.diffusion_redo):
@@ -574,7 +578,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, contro
gc.collect()
# generation
- image = generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
+ image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
if image is None:
break
diff --git a/scripts/deforum_helpers/render_modes.py b/scripts/deforum_helpers/render_modes.py
index 46730c991..388ecfa76 100644
--- a/scripts/deforum_helpers/render_modes.py
+++ b/scripts/deforum_helpers/render_modes.py
@@ -30,7 +30,7 @@
from .save_images import save_image
from .settings import save_settings_from_animation_run
-def render_input_video(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root):
+def render_input_video(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root):
# create a folder for the video input frames to live in
video_in_frame_path = os.path.join(args.outdir, 'inputframes')
os.makedirs(video_in_frame_path, exist_ok=True)
@@ -61,10 +61,10 @@ def render_input_video(args, anim_args, video_args, parseq_args, loop_args, cont
args.use_mask = True
args.overlay_mask = True
- render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root)
+ render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root)
# Modified a copy of the above to allow using masking video with out a init video.
-def render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root):
+def render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root):
# create a folder for the video input frames to live in
mask_in_frame_path = os.path.join(args.outdir, 'maskframes')
os.makedirs(mask_in_frame_path, exist_ok=True)
@@ -80,7 +80,7 @@ def render_animation_with_video_mask(args, anim_args, video_args, parseq_args, l
#args.use_init = True
print(f"Loading {anim_args.max_frames} input frames from {mask_in_frame_path} and saving video frames to {args.outdir}")
- render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root)
+ render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root)
def get_parsed_value(value, frame_idx, max_f):
pattern = r'`.*?`'
@@ -93,7 +93,7 @@ def get_parsed_value(value, frame_idx, max_f):
parsed_value = parsed_value.replace(matched_string, str(value))
return parsed_value
-def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root):
+def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root):
# use parseq if manifest is provided
parseq_adapter = ParseqAdapter(parseq_args, anim_args, video_args, controlnet_args, loop_args)
@@ -106,7 +106,7 @@ def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, co
print(f"Saving interpolation animation frames to {args.outdir}")
# save settings.txt file for the current run
- save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root)
+ save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, animatediff_args, video_args, root)
# Compute interpolated prompts
if parseq_adapter.manages_prompts():
@@ -162,7 +162,7 @@ def render_interpolation(args, anim_args, video_args, parseq_args, loop_args, co
args.seed = int(keys.seed_schedule_series[frame_idx]) if (args.seed_behavior == 'schedule' or parseq_adapter.manages_seed()) else args.seed
opts.data["CLIP_stop_at_last_layers"] = scheduled_clipskip if scheduled_clipskip is not None else opts.data["CLIP_stop_at_last_layers"]
- image = generate(args, keys, anim_args, loop_args, controlnet_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
+ image = generate(args, keys, anim_args, loop_args, controlnet_args, animatediff_args, root, parseq_adapter, frame_idx, sampler_name=scheduled_sampler_name)
filename = f"{root.timestring}_{frame_idx:09}.png"
save_image(image, 'PIL', filename, args, video_args, root)
diff --git a/scripts/deforum_helpers/run_deforum.py b/scripts/deforum_helpers/run_deforum.py
index ca25d152b..5065c81eb 100644
--- a/scripts/deforum_helpers/run_deforum.py
+++ b/scripts/deforum_helpers/run_deforum.py
@@ -71,7 +71,7 @@ def run_deforum(*args):
args_dict['self'] = None
args_dict['p'] = p
try:
- args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, controlnet_args = process_args(args_dict, i)
+ args_loaded_ok, root, args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args = process_args(args_dict, i)
except Exception as e:
JobStatusTracker().fail_job(job_id, error_type="TERMINAL", message="Invalid arguments.")
print("\n*START OF TRACEBACK*")
@@ -111,13 +111,13 @@ def run_deforum(*args):
JobStatusTracker().update_output_info(job_id, outdir=args.outdir, timestring=root.timestring)
if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D':
if anim_args.use_mask_video:
- render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root) # allow mask video without an input video
+ render_animation_with_video_mask(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root) # allow mask video without an input video
else:
- render_animation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root)
+ render_animation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root)
elif anim_args.animation_mode == 'Video Input':
- render_input_video(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root)#TODO: prettify code
+ render_input_video(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root)#TODO: prettify code
elif anim_args.animation_mode == 'Interpolation':
- render_interpolation(args, anim_args, video_args, parseq_args, loop_args, controlnet_args, root)
+ render_interpolation(args, anim_args, video_args, parseq_args, loop_args, animatediff_args, controlnet_args, root)
else:
print('Other modes are not available yet!')
except Exception as e:
@@ -221,7 +221,7 @@ def run_deforum(*args):
if shared.opts.data.get("deforum_enable_persistent_settings", False):
persistent_sett_path = shared.opts.data.get("deforum_persistent_settings_path")
- save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root, persistent_sett_path)
+ save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, animatediff_args, controlnet_args, video_args, root, persistent_sett_path)
# Close the pipeline, not to interfere with ControlNet
try:
diff --git a/scripts/deforum_helpers/settings.py b/scripts/deforum_helpers/settings.py
index bd1132d9b..214f5821f 100644
--- a/scripts/deforum_helpers/settings.py
+++ b/scripts/deforum_helpers/settings.py
@@ -58,7 +58,7 @@ def load_args(args_dict_main, args, anim_args, parseq_args, loop_args, controlne
return True
# save settings function that get calls when run_deforum is being called
-def save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, video_args, root, full_out_file_path = None):
+def save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, controlnet_args, animatediff_args, video_args, root, full_out_file_path = None):
if full_out_file_path:
args.__dict__["seed"] = root.raw_seed
args.__dict__["batch_name"] = root.raw_batch_name
@@ -69,7 +69,7 @@ def save_settings_from_animation_run(args, anim_args, parseq_args, loop_args, co
settings_filename = full_out_file_path if full_out_file_path else os.path.join(args.outdir, f"{root.timestring}_settings.txt")
with open(settings_filename, "w+", encoding="utf-8") as f:
s = {}
- for d in (args.__dict__, anim_args.__dict__, parseq_args.__dict__, loop_args.__dict__, controlnet_args.__dict__, video_args.__dict__):
+ for d in (args.__dict__, anim_args.__dict__, parseq_args.__dict__, loop_args.__dict__, controlnet_args.__dict__, animatediff_args.__dict__, video_args.__dict__):
s.update({k: v for k, v in d.items() if k not in exclude_keys})
s["sd_model_name"] = sh.sd_model.sd_checkpoint_info.name
s["sd_model_hash"] = sh.sd_model.sd_checkpoint_info.hash
diff --git a/scripts/deforum_helpers/ui_left.py b/scripts/deforum_helpers/ui_left.py
index aa1886372..57b4df03a 100644
--- a/scripts/deforum_helpers/ui_left.py
+++ b/scripts/deforum_helpers/ui_left.py
@@ -20,6 +20,7 @@
from .gradio_funcs import change_css, handle_change_functions
from .args import DeforumArgs, DeforumAnimArgs, ParseqArgs, DeforumOutputArgs, RootArgs, LoopArgs
from .deforum_controlnet import setup_controlnet_ui
+from .deforum_animatediff import setup_animatediff_ui
from .ui_elements import get_tab_run, get_tab_keyframes, get_tab_prompts, get_tab_init, get_tab_hybrid, get_tab_output
def set_arg_lists():
@@ -47,11 +48,12 @@ def setup_deforum_left_side_ui():
tab_keyframes_params = get_tab_keyframes(d, da, dloopArgs) # Keyframes tab
tab_prompts_params = get_tab_prompts(da) # Prompts tab
tab_init_params = get_tab_init(d, da, dp) # Init tab
+ animatediff_dict = setup_animatediff_ui() # AnimateDiff tab
controlnet_dict = setup_controlnet_ui() # ControlNet tab
tab_hybrid_params = get_tab_hybrid(da) # Hybrid tab
tab_output_params = get_tab_output(da, dv) # Output tab
# add returned gradio elements from main tabs to locals()
- for key, value in {**tab_run_params, **tab_keyframes_params, **tab_prompts_params, **tab_init_params, **controlnet_dict, **tab_hybrid_params, **tab_output_params}.items():
+ for key, value in {**tab_run_params, **tab_keyframes_params, **tab_prompts_params, **tab_init_params, **animatediff_dict, **controlnet_dict, **tab_hybrid_params, **tab_output_params}.items():
locals()[key] = value
# Gradio's Change functions - hiding and renaming elements based on other elements