Skip to content

Commit

Permalink
Implement support for block_lr network argument
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Aug 15, 2023
1 parent 9b3f7cc commit 859f71b
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def save_configuration(
up_lr_weight,
block_lr_zero_threshold,
block_dims,
block_alphas,
block_alphas,block_lr,
conv_block_dims,
conv_block_alphas,
weighted_captions,
Expand Down Expand Up @@ -292,7 +292,7 @@ def open_configuration(
up_lr_weight,
block_lr_zero_threshold,
block_dims,
block_alphas,
block_alphas,block_lr,
conv_block_dims,
conv_block_alphas,
weighted_captions,
Expand Down Expand Up @@ -463,7 +463,7 @@ def train_model(
up_lr_weight,
block_lr_zero_threshold,
block_dims,
block_alphas,
block_alphas,block_lr,
conv_block_dims,
conv_block_alphas,
weighted_captions,
Expand Down Expand Up @@ -788,6 +788,7 @@ def train_model(
'block_lr_zero_threshold',
'block_dims',
'block_alphas',
'block_lr',
'conv_block_dims',
'conv_block_alphas',
'rank_dropout',
Expand Down Expand Up @@ -822,6 +823,7 @@ def train_model(
'block_lr_zero_threshold',
'block_dims',
'block_alphas',
'block_lr',
'conv_block_dims',
'conv_block_alphas',
'rank_dropout',
Expand All @@ -844,6 +846,9 @@ def train_model(

if network_args:
run_cmd += f' --network_args{network_args}'

# if not block_lr == '':
# run_cmd += f' --block_lr="{block_lr}"'

if not (float(text_encoder_lr) == 0) or not (float(unet_lr) == 0):
if not (float(text_encoder_lr) == 0) and not (float(unet_lr) == 0):
Expand Down Expand Up @@ -1369,6 +1374,11 @@ def update_LoRA_settings(LoRA_type):
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.',
)
block_lr = gr.Textbox(
label='Block LR',
placeholder='(Optional)',
info='Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 1e-3',
)
with gr.Tab(label='Conv'):
with gr.Row(visible=True):
conv_block_dims = gr.Textbox(
Expand Down Expand Up @@ -1540,6 +1550,7 @@ def update_LoRA_settings(LoRA_type):
block_lr_zero_threshold,
block_dims,
block_alphas,
block_lr,
conv_block_dims,
conv_block_alphas,
advanced_training.weighted_captions,
Expand Down

0 comments on commit 859f71b

Please sign in to comment.