Commit d63a7fa

Refactor: Clean up LR handling logic in LoRA GUI
This commit refactors the learning rate (LR) handling in `kohya_gui/lora_gui.py` for LoRA training. The previous fix for LR misinterpretation involved commenting out a line. This commit completes the cleanup by:

- Removing the `do_not_set_learning_rate` variable and its associated conditional logic, which became redundant.
- Renaming the float-converted `learning_rate` to `learning_rate_float` for clarity.
- Ensuring that `learning_rate_float` and the float-converted `unet_lr_float` are consistently used when preparing the `config_toml_data` for the training script.

This makes the code cleaner and the intent of always passing the main learning rate (along with specific TE/UNet LRs) more direct. The functional behavior of the LR fix remains the same.
1 parent 3a8b599 commit d63a7fa
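
For readers who want the end state without stepping through the diff, below is a minimal, self-contained sketch of the LR handling this commit leaves in place. The wrapper function resolve_lora_learning_rates, its return shape, and the ValueError are illustrative assumptions only; in lora_gui.py the same logic lives inline in train_model, reports problems via output_message, and writes the values into config_toml_data.

# Hypothetical wrapper for illustration; the real logic is inline in train_model().
def resolve_lora_learning_rates(learning_rate, text_encoder_lr, unet_lr):
    # Convert each learning rate to float once; None (empty GUI field) becomes 0.0.
    learning_rate_float = float(learning_rate) if learning_rate is not None else 0.0
    text_encoder_lr_float = float(text_encoder_lr) if text_encoder_lr is not None else 0.0
    unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0

    # All three zero means there is nothing to train.
    if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
        raise ValueError("Please input learning rate values.")

    # Component flags, mirroring the logic kept by this commit.
    train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
    train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0

    # The main learning rate is now always passed through; unet_lr only when non-zero.
    config = {
        "learning_rate": learning_rate_float,
        "unet_lr": unet_lr_float if unet_lr_float != 0.0 else None,
    }
    return config, train_text_encoder_only, train_unet_only

# Example: text-encoder-only run; the main LR is still written, unet_lr is omitted.
cfg, te_only, unet_only = resolve_lora_learning_rates("1e-4", "5e-5", 0)
print(cfg, te_only, unet_only)  # {'learning_rate': 0.0001, 'unet_lr': None} True False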

File tree

1 file changed (+4, -9 lines)

kohya_gui/lora_gui.py

Lines changed: 4 additions & 9 deletions
@@ -1421,27 +1421,22 @@ def train_model(
     text_encoder_lr_list = [float(text_encoder_lr), float(text_encoder_lr)]
 
     # Convert learning rates to float once and store the result for re-use
-    learning_rate = float(learning_rate) if learning_rate is not None else 0.0
+    learning_rate_float = float(learning_rate) if learning_rate is not None else 0.0
     text_encoder_lr_float = (
         float(text_encoder_lr) if text_encoder_lr is not None else 0.0
     )
     unet_lr_float = float(unet_lr) if unet_lr is not None else 0.0
 
     # Determine the training configuration based on learning rate values
     # Sets flags for training specific components based on the provided learning rates.
-    if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
+    if learning_rate_float == unet_lr_float == text_encoder_lr_float == 0:
         output_message(msg="Please input learning rate values.", headless=headless)
         return TRAIN_BUTTON_VISIBLE
     # Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
     network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
     # Flag to train unet only if its learning rate is non-zero and text encoder's is zero.
     network_train_unet_only = text_encoder_lr_float == 0 and unet_lr_float != 0
 
-    do_not_set_learning_rate = False  # Initialize with a default value
-    if text_encoder_lr_float != 0 or unet_lr_float != 0:
-        log.info("Learning rate won't be used for training because text_encoder_lr or unet_lr is set.")
-        # do_not_set_learning_rate = True  # This line is now commented out
-
     clip_l_value = None
     if sd3_checkbox:
         # print("Setting clip_l_value to sd3_clip_l")
@@ -1519,7 +1514,7 @@ def train_model(
         "ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
         "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
         "keep_tokens": int(keep_tokens),
-        "learning_rate": None if do_not_set_learning_rate else learning_rate,
+        "learning_rate": learning_rate_float,
         "logging_dir": logging_dir,
         "log_config": log_config,
         "log_tracker_name": log_tracker_name,
@@ -1640,7 +1635,7 @@ def train_model(
         "train_batch_size": train_batch_size,
         "train_data_dir": train_data_dir,
         "training_comment": training_comment,
-        "unet_lr": unet_lr if unet_lr != 0 else None,
+        "unet_lr": unet_lr_float if unet_lr_float != 0.0 else None,
         "log_with": log_with,
         "v2": v2,
         "v_parameterization": v_parameterization,
