Commit dfbe972

[training] fixes to the quantization training script and add AdEMAMix optimizer as an option (#9806)
* fixes
* more fixes.
1 parent bbbd1c0 commit dfbe972

File tree

1 file changed: +31, -14 lines changed

examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py

Lines changed: 31 additions & 14 deletions
@@ -349,14 +349,19 @@ def parse_args(input_args=None):
         "--optimizer",
         type=str,
         default="AdamW",
-        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+        choices=["AdamW", "Prodigy", "AdEMAMix"],
     )
 
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
     )
+    parser.add_argument(
+        "--use_8bit_ademamix",
+        action="store_true",
+        help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
+    )
 
     parser.add_argument(
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,16 +825,15 @@ def load_model_hook(models, input_dir):
     params_to_optimize = [transformer_parameters_with_lr]
 
     # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
         logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
         )
-        args.optimizer = "adamw"
 
-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
         logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
 
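Taken together with the choices=[...] list added in the earlier hunk, an unsupported optimizer name now fails at argument-parsing time instead of being silently remapped to AdamW, which is why the warning-and-default branch above could be dropped. A minimal standalone sketch of that behaviour (illustrative only, not part of the commit):

import argparse

# Mirrors the --optimizer definition added in this commit; everything else is illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--optimizer", type=str, default="AdamW", choices=["AdamW", "Prodigy", "AdEMAMix"])

print(parser.parse_args(["--optimizer", "AdEMAMix"]).optimizer)  # AdEMAMix
parser.parse_args(["--optimizer", "sgd"])  # exits with: error: argument --optimizer: invalid choice: 'sgd'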
@@ -853,6 +857,20 @@ def load_model_hook(models, input_dir):
             eps=args.adam_epsilon,
         )
 
+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+
+        optimizer = optimizer_class(params_to_optimize)
+
     if args.optimizer.lower() == "prodigy":
         try:
             import prodigyopt
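For context, a minimal standalone sketch of what the new AdEMAMix branch does, assuming a bitsandbytes release that ships bnb.optim.AdEMAMix and bnb.optim.AdEMAMix8bit; the toy model and flag value below are placeholders, not part of the commit:

import torch
import bitsandbytes as bnb  # needs a release that includes the AdEMAMix optimizers

use_8bit_ademamix = False  # mirrors the --use_8bit_ademamix flag
# Same selection logic as the diff: pick the 8-bit variant when requested.
optimizer_class = bnb.optim.AdEMAMix8bit if use_8bit_ademamix else bnb.optim.AdEMAMix

# bitsandbytes optimizers expect CUDA parameters, so only run the step on a GPU machine.
if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()            # placeholder for the trainable LoRA params
    optimizer = optimizer_class(model.parameters())   # default hyperparameters, as in the script
    model(torch.randn(4, 16, device="cuda")).sum().backward()
    optimizer.step()
    optimizer.zero_grad()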
@@ -868,7 +886,6 @@ def load_model_hook(models, input_dir):
 
         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
@@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
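For reference, the shape arithmetic behind this hunk, assuming the standard Flux VAE config (block_out_channels has four entries) and a hypothetical 1024x1024 training resolution; neither value appears in the diff:

block_out_channels = [128, 256, 512, 512]              # 4 entries in the Flux VAE config (assumed)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 2**3 = 8; the old code gave 2**4 = 16

pixel_height = pixel_width = 1024                      # assumed training resolution
latent_height = pixel_height // vae_scale_factor       # 1024 // 8 = 128 == model_input.shape[2]
latent_width = pixel_width // vae_scale_factor         # 128 == model_input.shape[3]

# Flux packs 2x2 latent patches into single tokens, so the positional ids
# are built on the halved latent grid -- hence the new `// 2` arguments.
ids_height, ids_width = latent_height // 2, latent_width // 2   # 64, 64
num_image_tokens = ids_height * ids_width                       # 4096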
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
 
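The same arithmetic explains the simplified _unpack_latents call above: with the corrected scale factor, the latent dims multiplied by vae_scale_factor already give the pixel dims, so the division by 2 that compensated for the doubled scale factor is dropped. Continuing the assumed 1024x1024 example:

latent_height = latent_width = 128           # model_input.shape[2], model_input.shape[3]
vae_scale_factor = 8
height = latent_height * vae_scale_factor    # 1024, passed to FluxPipeline._unpack_latents
width = latent_width * vae_scale_factor      # 1024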