Skip to content

Commit d0cd9f5

Browse files
committed
Reverting all changes for max_train_steps
1 parent 309a9bb commit d0cd9f5

File tree

4 files changed

+58
-45
lines changed

4 files changed

+58
-45
lines changed

kohya_gui/dreambooth_gui.py

+18-20
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,6 @@ def train_model(
654654
gr.Button(visible=False or headless),
655655
gr.Textbox(value=train_state_value),
656656
]
657-
658-
max_train_steps_info = "Automatic by sd-scripts"
659657

660658
if executor.is_running():
661659
log.error("Training is already running. Can't start another training session.")
@@ -783,23 +781,23 @@ def train_model(
783781

784782
log.info(f"Regularization factor: {reg_factor}")
785783

786-
# if max_train_steps == 0:
787-
# # calculate max_train_steps
788-
# max_train_steps = int(
789-
# math.ceil(
790-
# float(total_steps)
791-
# / int(train_batch_size)
792-
# / int(gradient_accumulation_steps)
793-
# * int(epoch)
794-
# * int(reg_factor)
795-
# )
796-
# )
797-
# max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
798-
# else:
799-
# if max_train_steps == 0:
800-
# max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
801-
# else:
802-
# max_train_steps_info = f"Max train steps: {max_train_steps}"
784+
if max_train_steps == 0:
785+
# calculate max_train_steps
786+
max_train_steps = int(
787+
math.ceil(
788+
float(total_steps)
789+
/ int(train_batch_size)
790+
/ int(gradient_accumulation_steps)
791+
* int(epoch)
792+
* int(reg_factor)
793+
)
794+
)
795+
max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
796+
else:
797+
if max_train_steps == 0:
798+
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
799+
else:
800+
max_train_steps_info = f"Max train steps: {max_train_steps}"
803801

804802
log.info(f"Total steps: {total_steps}")
805803

@@ -1470,4 +1468,4 @@ def dreambooth_tab(
14701468
folders.reg_data_dir,
14711469
folders.output_dir,
14721470
folders.logging_dir,
1473-
)
1471+
)

kohya_gui/finetune_gui.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,6 @@ def train_model(
685685
gr.Button(visible=False or headless),
686686
gr.Textbox(value=train_state_value),
687687
]
688-
689-
max_train_steps_info = "Automatic by sd-scripts"
690688

691689
if executor.is_running():
692690
log.error("Training is already running. Can't start another training session.")
@@ -848,16 +846,16 @@ def train_model(
848846
repeats = int(image_num) * int(dataset_repeats)
849847
log.info(f"repeats = {str(repeats)}")
850848

851-
# if max_train_steps == 0:
852-
# # calculate max_train_steps
853-
# max_train_steps = int(
854-
# math.ceil(
855-
# float(repeats)
856-
# / int(train_batch_size)
857-
# / int(gradient_accumulation_steps)
858-
# * int(epoch)
859-
# )
860-
# )
849+
if max_train_steps == 0:
850+
# calculate max_train_steps
851+
max_train_steps = int(
852+
math.ceil(
853+
float(repeats)
854+
/ int(train_batch_size)
855+
/ int(gradient_accumulation_steps)
856+
* int(epoch)
857+
)
858+
)
861859

862860
# Divide by two because flip augmentation create two copied of the source images
863861
if flip_aug and max_train_steps:
@@ -1634,4 +1632,4 @@ def list_presets(path):
16341632
if os.path.exists(top_level_path):
16351633
with open(os.path.join(top_level_path), "r", encoding="utf-8") as file:
16361634
guides_top_level = file.read() + "\n"
1637-
gr.Markdown(guides_top_level)
1635+
gr.Markdown(guides_top_level)

kohya_gui/lora_gui.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -883,8 +883,6 @@ def train_model(
883883
gr.Button(visible=False or headless),
884884
gr.Textbox(value=train_state_value),
885885
]
886-
887-
max_train_steps_info = "Automatic by sd-scripts"
888886

889887
if executor.is_running():
890888
log.error("Training is already running. Can't start another training session.")
@@ -1078,7 +1076,7 @@ def train_model(
10781076

10791077
log.info(f"Regularization factor: {reg_factor}")
10801078

1081-
if (max_train_steps == 0) and (stop_text_encoder_training != 0):
1079+
if max_train_steps == 0:
10821080
# calculate max_train_steps
10831081
max_train_steps = int(
10841082
math.ceil(
@@ -1096,9 +1094,13 @@ def train_model(
10961094
else:
10971095
max_train_steps_info = f"Max train steps: {max_train_steps}"
10981096

1099-
stop_text_encoder_training = math.ceil(
1100-
float(max_train_steps) / 100 * int(stop_text_encoder_training)
1101-
) if stop_text_encoder_training != 0 else 0
1097+
# calculate stop encoder training
1098+
if stop_text_encoder_training == 0:
1099+
stop_text_encoder_training = 0
1100+
else:
1101+
stop_text_encoder_training = math.ceil(
1102+
float(max_train_steps) / 100 * int(stop_text_encoder_training)
1103+
)
11021104

11031105
# Calculate lr_warmup_steps
11041106
if lr_warmup_steps > 0:
@@ -2855,4 +2857,4 @@ def update_LoRA_settings(
28552857
folders.reg_data_dir,
28562858
folders.output_dir,
28572859
folders.logging_dir,
2858-
)
2860+
)

kohya_gui/textual_inversion_gui.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,6 @@ def train_model(
521521
gr.Textbox(value=train_state_value),
522522
]
523523

524-
max_train_steps_info = "Automatic by sd-scripts"
525-
526524
if executor.is_running():
527525
log.error("Training is already running. Can't start another training session.")
528526
return TRAIN_BUTTON_VISIBLE
@@ -666,9 +664,22 @@ def train_model(
666664
log.info(f"Regularization factor: {reg_factor}")
667665

668666
if max_train_steps == 0:
669-
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
667+
# calculate max_train_steps
668+
max_train_steps = int(
669+
math.ceil(
670+
float(total_steps)
671+
/ int(train_batch_size)
672+
/ int(gradient_accumulation_steps)
673+
* int(epoch)
674+
* int(reg_factor)
675+
)
676+
)
677+
max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
670678
else:
671-
max_train_steps_info = f"Max train steps: {max_train_steps}"
679+
if max_train_steps == 0:
680+
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
681+
else:
682+
max_train_steps_info = f"Max train steps: {max_train_steps}"
672683

673684
# calculate stop encoder training
674685
if stop_text_encoder_training_pct == 0:
@@ -1065,6 +1076,10 @@ def list_embedding_files(path):
10651076
step=1,
10661077
label="Vectors",
10671078
)
1079+
# max_train_steps = gr.Textbox(
1080+
# label='Max train steps',
1081+
# placeholder='(Optional) Maximum number of steps',
1082+
# )
10681083
template = gr.Dropdown(
10691084
label="Template",
10701085
choices=[
@@ -1294,4 +1309,4 @@ def list_embedding_files(path):
12941309
folders.reg_data_dir,
12951310
folders.output_dir,
12961311
folders.logging_dir,
1297-
)
1312+
)

0 commit comments

Comments
 (0)