-
Notifications
You must be signed in to change notification settings - Fork 6k
[training] fixes to the quantization training script and add AdEMAMix optimizer as an option #9806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -349,14 +349,19 @@ def parse_args(input_args=None): | |
"--optimizer", | ||
type=str, | ||
default="AdamW", | ||
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), | ||
choices=["AdamW", "Prodigy", "AdEMAMix"], | ||
) | ||
|
||
parser.add_argument( | ||
"--use_8bit_adam", | ||
action="store_true", | ||
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", | ||
) | ||
parser.add_argument( | ||
"--use_8bit_ademamix", | ||
action="store_true", | ||
help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.", | ||
) | ||
|
||
parser.add_argument( | ||
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." | ||
|
@@ -820,16 +825,15 @@ def load_model_hook(models, input_dir): | |
params_to_optimize = [transformer_parameters_with_lr] | ||
|
||
# Optimizer creation | ||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): | ||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw": | ||
logger.warning( | ||
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." | ||
"Defaulting to adamW" | ||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " | ||
f"set to {args.optimizer.lower()}" | ||
) | ||
args.optimizer = "adamw" | ||
|
||
if args.use_8bit_adam and not args.optimizer.lower() == "adamw": | ||
if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix": | ||
logger.warning( | ||
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " | ||
f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was " | ||
f"set to {args.optimizer.lower()}" | ||
) | ||
|
||
|
@@ -853,6 +857,20 @@ def load_model_hook(models, input_dir): | |
eps=args.adam_epsilon, | ||
) | ||
|
||
elif args.optimizer.lower() == "ademamix": | ||
try: | ||
import bitsandbytes as bnb | ||
except ImportError: | ||
raise ImportError( | ||
"To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`." | ||
) | ||
if args.use_8bit_ademamix: | ||
optimizer_class = bnb.optim.AdEMAMix8bit | ||
else: | ||
optimizer_class = bnb.optim.AdEMAMix | ||
|
||
optimizer = optimizer_class(params_to_optimize) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we support
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Umm I didn't want to actually to keep the separations of concern very clear. We could maybe revisit if the community finds the optimizer worth the go? |
||
|
||
if args.optimizer.lower() == "prodigy": | ||
try: | ||
import prodigyopt | ||
|
@@ -868,7 +886,6 @@ def load_model_hook(models, input_dir): | |
|
||
optimizer = optimizer_class( | ||
params_to_optimize, | ||
lr=args.learning_rate, | ||
betas=(args.adam_beta1, args.adam_beta2), | ||
beta3=args.prodigy_beta3, | ||
weight_decay=args.adam_weight_decay, | ||
|
@@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | |
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor | ||
model_input = model_input.to(dtype=weight_dtype) | ||
|
||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) | ||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) | ||
|
||
latent_image_ids = FluxPipeline._prepare_latent_image_ids( | ||
model_input.shape[0], | ||
model_input.shape[2], | ||
model_input.shape[3], | ||
model_input.shape[2] // 2, | ||
model_input.shape[3] // 2, | ||
accelerator.device, | ||
weight_dtype, | ||
Comment on lines
+1040
to
1047
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Follows what we do in the Flux LoRA scripts. |
||
) | ||
|
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | |
) | ||
|
||
# handle guidance | ||
if transformer.config.guidance_embeds: | ||
if unwrap_model(transformer).config.guidance_embeds: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So that things are compatible with DeepSpeed. |
||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device) | ||
guidance = guidance.expand(model_input.shape[0]) | ||
else: | ||
|
@@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): | |
)[0] | ||
model_pred = FluxPipeline._unpack_latents( | ||
model_pred, | ||
height=int(model_input.shape[2] * vae_scale_factor / 2), | ||
width=int(model_input.shape[3] * vae_scale_factor / 2), | ||
height=model_input.shape[2] * vae_scale_factor, | ||
width=model_input.shape[3] * vae_scale_factor, | ||
Comment on lines
+1102
to
+1103
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
vae_scale_factor=vae_scale_factor, | ||
) | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.