@@ -349,14 +349,19 @@ def parse_args(input_args=None):
349
349
"--optimizer" ,
350
350
type = str ,
351
351
default = "AdamW" ,
352
- help = ( 'The optimizer type to use. Choose between ["AdamW", "prodigy"]' ) ,
352
+ choices = ["AdamW" , "Prodigy" , "AdEMAMix" ] ,
353
353
)
354
354
355
355
parser .add_argument (
356
356
"--use_8bit_adam" ,
357
357
action = "store_true" ,
358
358
help = "Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW" ,
359
359
)
360
+ parser .add_argument (
361
+ "--use_8bit_ademamix" ,
362
+ action = "store_true" ,
363
+ help = "Whether or not to use 8-bit AdEMAMix from bitsandbytes." ,
364
+ )
360
365
361
366
parser .add_argument (
362
367
"--adam_beta1" , type = float , default = 0.9 , help = "The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,16 +825,15 @@ def load_model_hook(models, input_dir):
820
825
params_to_optimize = [transformer_parameters_with_lr ]
821
826
822
827
# Optimizer creation
823
- if not ( args .optimizer . lower () == "prodigy" or args .optimizer .lower () == "adamw" ) :
828
+ if args .use_8bit_adam and not args .optimizer .lower () == "adamw" :
824
829
logger .warning (
825
- f"Unsupported choice of optimizer: { args . optimizer } .Supported optimizers include [adamW, prodigy]. "
826
- "Defaulting to adamW "
830
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
831
+ f"set to { args . optimizer . lower () } "
827
832
)
828
- args .optimizer = "adamw"
829
833
830
- if args .use_8bit_adam and not args .optimizer .lower () == "adamw " :
834
+ if args .use_8bit_ademamix and not args .optimizer .lower () == "ademamix " :
831
835
logger .warning (
832
- f"use_8bit_adam is ignored when optimizer is not set to 'AdamW '. Optimizer was "
836
+ f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix '. Optimizer was "
833
837
f"set to { args .optimizer .lower ()} "
834
838
)
835
839
@@ -853,6 +857,20 @@ def load_model_hook(models, input_dir):
853
857
eps = args .adam_epsilon ,
854
858
)
855
859
860
+ elif args .optimizer .lower () == "ademamix" :
861
+ try :
862
+ import bitsandbytes as bnb
863
+ except ImportError :
864
+ raise ImportError (
865
+ "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
866
+ )
867
+ if args .use_8bit_ademamix :
868
+ optimizer_class = bnb .optim .AdEMAMix8bit
869
+ else :
870
+ optimizer_class = bnb .optim .AdEMAMix
871
+
872
+ optimizer = optimizer_class (params_to_optimize )
873
+
856
874
if args .optimizer .lower () == "prodigy" :
857
875
try :
858
876
import prodigyopt
@@ -868,7 +886,6 @@ def load_model_hook(models, input_dir):
868
886
869
887
optimizer = optimizer_class (
870
888
params_to_optimize ,
871
- lr = args .learning_rate ,
872
889
betas = (args .adam_beta1 , args .adam_beta2 ),
873
890
beta3 = args .prodigy_beta3 ,
874
891
weight_decay = args .adam_weight_decay ,
@@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1020
1037
model_input = (model_input - vae_config_shift_factor ) * vae_config_scaling_factor
1021
1038
model_input = model_input .to (dtype = weight_dtype )
1022
1039
1023
- vae_scale_factor = 2 ** (len (vae_config_block_out_channels ))
1040
+ vae_scale_factor = 2 ** (len (vae_config_block_out_channels ) - 1 )
1024
1041
1025
1042
latent_image_ids = FluxPipeline ._prepare_latent_image_ids (
1026
1043
model_input .shape [0 ],
1027
- model_input .shape [2 ],
1028
- model_input .shape [3 ],
1044
+ model_input .shape [2 ] // 2 ,
1045
+ model_input .shape [3 ] // 2 ,
1029
1046
accelerator .device ,
1030
1047
weight_dtype ,
1031
1048
)
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1059
1076
)
1060
1077
1061
1078
# handle guidance
1062
- if transformer .config .guidance_embeds :
1079
+ if unwrap_model ( transformer ) .config .guidance_embeds :
1063
1080
guidance = torch .tensor ([args .guidance_scale ], device = accelerator .device )
1064
1081
guidance = guidance .expand (model_input .shape [0 ])
1065
1082
else :
@@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1082
1099
)[0 ]
1083
1100
model_pred = FluxPipeline ._unpack_latents (
1084
1101
model_pred ,
1085
- height = int ( model_input .shape [2 ] * vae_scale_factor / 2 ) ,
1086
- width = int ( model_input .shape [3 ] * vae_scale_factor / 2 ) ,
1102
+ height = model_input .shape [2 ] * vae_scale_factor ,
1103
+ width = model_input .shape [3 ] * vae_scale_factor ,
1087
1104
vae_scale_factor = vae_scale_factor ,
1088
1105
)
1089
1106
0 commit comments