@@ -201,6 +201,32 @@ def get_dummy_tokens(self):
201
201
prepared_inputs ["input_ids" ] = inputs
202
202
return prepared_inputs
203
203
204
+ def _get_lora_state_dicts (self , modules_to_save ):
205
+ state_dicts = {}
206
+ for module_name , module in modules_to_save .items ():
207
+ if module is not None :
208
+ state_dicts [f"{ module_name } _lora_layers" ] = get_peft_model_state_dict (module )
209
+ return state_dicts
210
+
211
+ def _get_modules_to_save (self , pipe , has_denoiser = False ):
212
+ modules_to_save = {}
213
+ lora_loadable_modules = self .pipeline_class ._lora_loadable_modules
214
+
215
+ if "text_encoder" in lora_loadable_modules and hasattr (pipe , "text_encoder" ):
216
+ modules_to_save ["text_encoder" ] = pipe .text_encoder
217
+
218
+ if "text_encoder_2" in lora_loadable_modules and hasattr (pipe , "text_encoder_2" ):
219
+ modules_to_save ["text_encoder_2" ] = pipe .text_encoder_2
220
+
221
+ if has_denoiser :
222
+ if "unet" in lora_loadable_modules and hasattr (pipe , "unet" ):
223
+ modules_to_save ["unet" ] = pipe .unet
224
+
225
+ if "transformer" in lora_loadable_modules and hasattr (pipe , "transformer" ):
226
+ modules_to_save ["transformer" ] = pipe .transformer
227
+
228
+ return modules_to_save
229
+
204
230
def test_simple_inference (self ):
205
231
"""
206
232
Tests a simple inference and makes sure it works as expected
@@ -420,45 +446,21 @@ def test_simple_inference_with_text_lora_save_load(self):
420
446
images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
421
447
422
448
with tempfile .TemporaryDirectory () as tmpdirname :
423
- text_encoder_state_dict = get_peft_model_state_dict (pipe .text_encoder )
424
- if self .has_two_text_encoders or self .has_three_text_encoders :
425
- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
426
- text_encoder_2_state_dict = get_peft_model_state_dict (pipe .text_encoder_2 )
427
-
428
- self .pipeline_class .save_lora_weights (
429
- save_directory = tmpdirname ,
430
- text_encoder_lora_layers = text_encoder_state_dict ,
431
- text_encoder_2_lora_layers = text_encoder_2_state_dict ,
432
- safe_serialization = False ,
433
- )
434
- else :
435
- self .pipeline_class .save_lora_weights (
436
- save_directory = tmpdirname ,
437
- text_encoder_lora_layers = text_encoder_state_dict ,
438
- safe_serialization = False ,
439
- )
449
+ modules_to_save = self ._get_modules_to_save (pipe )
450
+ lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
440
451
441
- if self .has_two_text_encoders :
442
- if "text_encoder_2" not in self .pipeline_class ._lora_loadable_modules :
443
- self .pipeline_class .save_lora_weights (
444
- save_directory = tmpdirname ,
445
- text_encoder_lora_layers = text_encoder_state_dict ,
446
- safe_serialization = False ,
447
- )
452
+ self .pipeline_class .save_lora_weights (
453
+ save_directory = tmpdirname , safe_serialization = False , ** lora_state_dicts
454
+ )
448
455
449
456
self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.bin" )))
450
457
pipe .unload_lora_weights ()
451
-
452
458
pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ))
453
459
454
- images_lora_from_pretrained = pipe ( ** inputs , generator = torch . manual_seed ( 0 ))[ 0 ]
455
- self .assertTrue (check_if_lora_correctly_set (pipe . text_encoder ), "Lora not correctly set in text encoder " )
460
+ for module_name , module in modules_to_save . items ():
461
+ self .assertTrue (check_if_lora_correctly_set (module ), f "Lora not correctly set in { module_name } " )
456
462
457
- if self .has_two_text_encoders or self .has_three_text_encoders :
458
- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
459
- self .assertTrue (
460
- check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
461
- )
463
+ images_lora_from_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
462
464
463
465
self .assertTrue (
464
466
np .allclose (images_lora , images_lora_from_pretrained , atol = 1e-3 , rtol = 1e-3 ),
@@ -614,54 +616,20 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
614
616
images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
615
617
616
618
with tempfile .TemporaryDirectory () as tmpdirname :
617
- text_encoder_state_dict = (
618
- get_peft_model_state_dict ( pipe . text_encoder )
619
- if "text_encoder" in self .pipeline_class ._lora_loadable_modules
620
- else None
619
+ modules_to_save = self . _get_modules_to_save ( pipe , has_denoiser = True )
620
+ lora_state_dicts = self . _get_lora_state_dicts ( modules_to_save )
621
+ self .pipeline_class .save_lora_weights (
622
+ save_directory = tmpdirname , safe_serialization = False , ** lora_state_dicts
621
623
)
622
624
623
- denoiser_state_dict = get_peft_model_state_dict (denoiser )
624
-
625
- saving_kwargs = {
626
- "save_directory" : tmpdirname ,
627
- "safe_serialization" : False ,
628
- }
629
-
630
- if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
631
- saving_kwargs .update ({"text_encoder_lora_layers" : text_encoder_state_dict })
632
-
633
- if self .unet_kwargs is not None :
634
- saving_kwargs .update ({"unet_lora_layers" : denoiser_state_dict })
635
- else :
636
- saving_kwargs .update ({"transformer_lora_layers" : denoiser_state_dict })
637
-
638
- if self .has_two_text_encoders or self .has_three_text_encoders :
639
- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
640
- text_encoder_2_state_dict = get_peft_model_state_dict (pipe .text_encoder_2 )
641
- saving_kwargs .update ({"text_encoder_2_lora_layers" : text_encoder_2_state_dict })
642
-
643
- self .pipeline_class .save_lora_weights (** saving_kwargs )
644
-
645
625
self .assertTrue (os .path .isfile (os .path .join (tmpdirname , "pytorch_lora_weights.bin" )))
646
626
pipe .unload_lora_weights ()
647
-
648
627
pipe .load_lora_weights (os .path .join (tmpdirname , "pytorch_lora_weights.bin" ))
649
628
650
- images_lora_from_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
651
-
652
- if "text_encoder" in self .pipeline_class ._lora_loadable_modules :
653
- self .assertTrue (
654
- check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder"
655
- )
656
-
657
- self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser" )
658
-
659
- if self .has_two_text_encoders or self .has_three_text_encoders :
660
- if "text_encoder_2" in self .pipeline_class ._lora_loadable_modules :
661
- self .assertTrue (
662
- check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
663
- )
629
+ for module_name , module in modules_to_save .items ():
630
+ self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
664
631
632
+ images_lora_from_pretrained = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
665
633
self .assertTrue (
666
634
np .allclose (images_lora , images_lora_from_pretrained , atol = 1e-3 , rtol = 1e-3 ),
667
635
"Loading from saved checkpoints should give same results." ,
0 commit comments