Skip to content

Commit

Permalink
Apply blue to code
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Aug 15, 2023
1 parent d311bba commit c19bdc3
Show file tree
Hide file tree
Showing 24 changed files with 878 additions and 456 deletions.
88 changes: 55 additions & 33 deletions dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
output_message,
verify_image_folder_pattern,
SaveConfigFile,
save_to_file
save_to_file,
)
from library.class_configuration_file import ConfigurationFile
from library.class_source_model import SourceModel
Expand Down Expand Up @@ -99,7 +99,8 @@ def save_configuration(
vae,
output_name,
max_token_length,
max_train_epochs,max_train_steps,
max_train_epochs,
max_train_steps,
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
Expand All @@ -110,11 +111,13 @@ def save_configuration(
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps, v_pred_like_loss,
bucket_reso_steps,
v_pred_like_loss,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args, lr_scheduler_args,
optimizer_args,
lr_scheduler_args,
noise_offset_type,
noise_offset,
adaptive_noise_scale,
Expand Down Expand Up @@ -162,7 +165,11 @@ def save_configuration(
if not os.path.exists(destination_directory):
os.makedirs(destination_directory)

SaveConfigFile(parameters=parameters, file_path=file_path, exclusion=['file_path', 'save_as'])
SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as'],
)

return file_path

Expand Down Expand Up @@ -213,7 +220,8 @@ def open_configuration(
vae,
output_name,
max_token_length,
max_train_epochs,max_train_steps,
max_train_epochs,
max_train_steps,
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
Expand All @@ -224,11 +232,13 @@ def open_configuration(
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps, v_pred_like_loss,
bucket_reso_steps,
v_pred_like_loss,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args, lr_scheduler_args,
optimizer_args,
lr_scheduler_args,
noise_offset_type,
noise_offset,
adaptive_noise_scale,
Expand Down Expand Up @@ -326,7 +336,8 @@ def train_model(
vae,
output_name,
max_token_length,
max_train_epochs,max_train_steps,
max_train_epochs,
max_train_steps,
max_data_loader_n_workers,
mem_eff_attn,
gradient_accumulation_steps,
Expand All @@ -337,11 +348,13 @@ def train_model(
persistent_data_loader_workers,
bucket_no_upscale,
random_crop,
bucket_reso_steps, v_pred_like_loss,
bucket_reso_steps,
v_pred_like_loss,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args, lr_scheduler_args,
optimizer_args,
lr_scheduler_args,
noise_offset_type,
noise_offset,
adaptive_noise_scale,
Expand All @@ -366,7 +379,7 @@ def train_model(
):
# Get list of function parameters and values
parameters = list(locals().items())

print_only_bool = True if print_only.get('label') == 'True' else False
log.info(f'Start training Dreambooth...')

Expand Down Expand Up @@ -534,7 +547,7 @@ def train_model(
run_cmd += f' "./sdxl_train.py"'
else:
run_cmd += f' "./train_db.py"'

if v2:
run_cmd += ' --v2'
if v_parameterization:
Expand Down Expand Up @@ -659,22 +672,28 @@ def train_model(
'Here is the trainer command as a reference. It will not be executed:\n'
)
print(run_cmd)

save_to_file(run_cmd)
else:
# Saving config file for model
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
file_path = os.path.join(output_dir, f'{output_name}_{formatted_datetime}.json')

formatted_datetime = current_datetime.strftime('%Y%m%d-%H%M%S')
file_path = os.path.join(
output_dir, f'{output_name}_{formatted_datetime}.json'
)

log.info(f'Saving training config to {file_path}...')

SaveConfigFile(parameters=parameters, file_path=file_path, exclusion=['file_path', 'save_as', 'headless', 'print_only'])

SaveConfigFile(
parameters=parameters,
file_path=file_path,
exclusion=['file_path', 'save_as', 'headless', 'print_only'],
)

log.info(run_cmd)

# Run the command

executor.execute_command(run_cmd=run_cmd)

# check if output_dir/last is a folder... therefore it is a diffuser model
Expand All @@ -697,13 +716,15 @@ def dreambooth_tab(
dummy_db_true = gr.Label(value=True, visible=False)
dummy_db_false = gr.Label(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)

with gr.Tab('Training'):
gr.Markdown('Train a custom model using kohya dreambooth python code...')

gr.Markdown(
'Train a custom model using kohya dreambooth python code...'
)

# Setup Configuration Files Gradio
config = ConfigurationFile(headless)

source_model = SourceModel(headless=headless)

with gr.Tab('Folders'):
Expand All @@ -715,18 +736,18 @@ def dreambooth_tab(
lr_scheduler_value='cosine',
lr_warmup_value='10',
)

# # Add SDXL Parameters
# sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False)

with gr.Tab('Advanced', elem_id='advanced_tab'):
advanced_training = AdvancedTraining(headless=headless)
advanced_training.color_aug.change(
color_aug_changed,
inputs=[advanced_training.color_aug],
outputs=[basic_training.cache_latents],
)

with gr.Tab('Samples', elem_id='samples_tab'):
sample = SampleImages()

Expand All @@ -745,13 +766,16 @@ def dreambooth_tab(

with gr.Row():
button_run = gr.Button('Start training', variant='primary')

button_stop_training = gr.Button('Stop training')

button_print = gr.Button('Print training command')

# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
(
button_start_tensorboard,
button_stop_tensorboard,
) = gradio_tensorboard()

button_start_tensorboard.click(
start_tensorboard,
Expand Down Expand Up @@ -882,10 +906,8 @@ def dreambooth_tab(
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
show_progress=False,
)

button_stop_training.click(
executor.kill_command
)

button_stop_training.click(executor.kill_command)

button_print.click(
train_model,
Expand Down
Loading

0 comments on commit c19bdc3

Please sign in to comment.