Skip to content

Commit a6f0ff7

Browse files
committed
Fix issue with max_train_steps
1 parent d47e3e6 commit a6f0ff7

File tree

4 files changed

+33
-54
lines changed

4 files changed

+33
-54
lines changed

kohya_gui/dreambooth_gui.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -781,23 +781,23 @@ def train_model(
781781

782782
log.info(f"Regularization factor: {reg_factor}")
783783

784-
if max_train_steps == 0:
785-
# calculate max_train_steps
786-
max_train_steps = int(
787-
math.ceil(
788-
float(total_steps)
789-
/ int(train_batch_size)
790-
/ int(gradient_accumulation_steps)
791-
* int(epoch)
792-
* int(reg_factor)
793-
)
794-
)
795-
max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
796-
else:
797-
if max_train_steps == 0:
798-
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
799-
else:
800-
max_train_steps_info = f"Max train steps: {max_train_steps}"
784+
# if max_train_steps == 0:
785+
# # calculate max_train_steps
786+
# max_train_steps = int(
787+
# math.ceil(
788+
# float(total_steps)
789+
# / int(train_batch_size)
790+
# / int(gradient_accumulation_steps)
791+
# * int(epoch)
792+
# * int(reg_factor)
793+
# )
794+
# )
795+
# max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
796+
# else:
797+
# if max_train_steps == 0:
798+
# max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
799+
# else:
800+
# max_train_steps_info = f"Max train steps: {max_train_steps}"
801801

802802
log.info(f"Total steps: {total_steps}")
803803

kohya_gui/finetune_gui.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -846,16 +846,16 @@ def train_model(
846846
repeats = int(image_num) * int(dataset_repeats)
847847
log.info(f"repeats = {str(repeats)}")
848848

849-
if max_train_steps == 0:
850-
# calculate max_train_steps
851-
max_train_steps = int(
852-
math.ceil(
853-
float(repeats)
854-
/ int(train_batch_size)
855-
/ int(gradient_accumulation_steps)
856-
* int(epoch)
857-
)
858-
)
849+
# if max_train_steps == 0:
850+
# # calculate max_train_steps
851+
# max_train_steps = int(
852+
# math.ceil(
853+
# float(repeats)
854+
# / int(train_batch_size)
855+
# / int(gradient_accumulation_steps)
856+
# * int(epoch)
857+
# )
858+
# )
859859

860860
# Divide by two because flip augmentation creates two copies of the source images
861861
if flip_aug and max_train_steps:

kohya_gui/lora_gui.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ def train_model(
10761076

10771077
log.info(f"Regularization factor: {reg_factor}")
10781078

1079-
if max_train_steps == 0:
1079+
if (max_train_steps == 0) and (stop_text_encoder_training != 0):
10801080
# calculate max_train_steps
10811081
max_train_steps = int(
10821082
math.ceil(
@@ -1094,13 +1094,9 @@ def train_model(
10941094
else:
10951095
max_train_steps_info = f"Max train steps: {max_train_steps}"
10961096

1097-
# calculate stop encoder training
1098-
if stop_text_encoder_training == 0:
1099-
stop_text_encoder_training = 0
1100-
else:
1101-
stop_text_encoder_training = math.ceil(
1102-
float(max_train_steps) / 100 * int(stop_text_encoder_training)
1103-
)
1097+
stop_text_encoder_training = math.ceil(
1098+
float(max_train_steps) / 100 * int(stop_text_encoder_training)
1099+
) if stop_text_encoder_training != 0 else 0
11041100

11051101
# Calculate lr_warmup_steps
11061102
if lr_warmup_steps > 0:

kohya_gui/textual_inversion_gui.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -664,22 +664,9 @@ def train_model(
664664
log.info(f"Regularization factor: {reg_factor}")
665665

666666
if max_train_steps == 0:
667-
# calculate max_train_steps
668-
max_train_steps = int(
669-
math.ceil(
670-
float(total_steps)
671-
/ int(train_batch_size)
672-
/ int(gradient_accumulation_steps)
673-
* int(epoch)
674-
* int(reg_factor)
675-
)
676-
)
677-
max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
667+
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
678668
else:
679-
if max_train_steps == 0:
680-
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
681-
else:
682-
max_train_steps_info = f"Max train steps: {max_train_steps}"
669+
max_train_steps_info = f"Max train steps: {max_train_steps}"
683670

684671
# calculate stop encoder training
685672
if stop_text_encoder_training_pct == 0:
@@ -1076,10 +1063,6 @@ def list_embedding_files(path):
10761063
step=1,
10771064
label="Vectors",
10781065
)
1079-
# max_train_steps = gr.Textbox(
1080-
# label='Max train steps',
1081-
# placeholder='(Optional) Maximum number of steps',
1082-
# )
10831066
template = gr.Dropdown(
10841067
label="Template",
10851068
choices=[

0 commit comments

Comments (0)