Skip to content

Commit

Permalink
Update sd-scripts and add support for t5xxl LR
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Sep 11, 2024
1 parent 63c1e48 commit f365b63
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
32 changes: 29 additions & 3 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def save_configuration(

###
text_encoder_lr,
t5xxl_lr,
unet_lr,
network_dim,
network_weights,
Expand Down Expand Up @@ -453,6 +454,7 @@ def open_configuration(

###
text_encoder_lr,
t5xxl_lr,
unet_lr,
network_dim,
network_weights,
Expand Down Expand Up @@ -734,6 +736,7 @@ def train_model(

###
text_encoder_lr,
t5xxl_lr,
unet_lr,
network_dim,
network_weights,
Expand Down Expand Up @@ -1249,6 +1252,20 @@ def train_model(
if value:
network_args += f" {key}={value}"

# Set the text_encoder_lr to multiple values if both text_encoder_lr and t5xxl_lr are set
if text_encoder_lr == 0 and t5xxl_lr > 0:
log.error("When specifying T5XXL learning rate, text encoder learning rate need to be a value greater than 0.")
return TRAIN_BUTTON_VISIBLE

text_encoder_lr_str = ""

if text_encoder_lr > 0 and t5xxl_lr > 0:
# Set the text_encoder_lr to a combination of text_encoder_lr and t5xxl_lr
text_encoder_lr_str = f"{text_encoder_lr} {t5xxl_lr}"
elif text_encoder_lr > 0:
# Set the text_encoder_lr to text_encoder_lr only
text_encoder_lr_str = f"{text_encoder_lr}"

# Convert learning rates to float once and store the result for re-use
learning_rate = float(learning_rate) if learning_rate is not None else 0.0
text_encoder_lr_float = (
Expand Down Expand Up @@ -1427,7 +1444,7 @@ def train_model(
"stop_text_encoder_training": (
stop_text_encoder_training if stop_text_encoder_training != 0 else None
),
"text_encoder_lr": text_encoder_lr if not 0 else None,
"text_encoder_lr": text_encoder_lr_str if not 0 else None,
"train_batch_size": train_batch_size,
"train_data_dir": train_data_dir,
"training_comment": training_comment,
Expand Down Expand Up @@ -1682,8 +1699,16 @@ def list_presets(path):
with gr.Row():
text_encoder_lr = gr.Number(
label="Text Encoder learning rate",
value=0.0001,
info="(Optional)",
value=0,
info="(Optional) Set CLIP-L and T5XXL learning rates.",
minimum=0,
maximum=1,
)

t5xxl_lr = gr.Number(
label="T5XXL learning rate",
value=0,
info="(Optional) Override the T5XXL learning rate set by the Text Encoder learning rate if you desire a different one.",
minimum=0,
maximum=1,
)
Expand Down Expand Up @@ -2558,6 +2583,7 @@ def update_LoRA_settings(
sdxl_params.sdxl_cache_text_encoder_outputs,
sdxl_params.sdxl_no_half_vae,
text_encoder_lr,
t5xxl_lr,
unet_lr,
network_dim,
network_weights,
Expand Down

0 comments on commit f365b63

Please sign in to comment.