Skip to content

Commit 9405b1a

Browse files
committed
[Tests] [LoRA] clean up the serialization stuff. (#9512)
- clean up the serialization stuff.
- better
1 parent 1939920 commit 9405b1a

File tree

1 file changed

+41
-73
lines changed

1 file changed

+41
-73
lines changed

tests/lora/utils.py

Lines changed: 41 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,32 @@ def get_dummy_tokens(self):
201201
prepared_inputs["input_ids"] = inputs
202202
return prepared_inputs
203203

204+
def _get_lora_state_dicts(self, modules_to_save):
    """Collect a PEFT LoRA state dict for every non-None module.

    Keys follow the ``{module_name}_lora_layers`` naming convention that
    ``save_lora_weights`` expects as keyword arguments.
    """
    return {
        f"{module_name}_lora_layers": get_peft_model_state_dict(module)
        for module_name, module in modules_to_save.items()
        if module is not None
    }
210+
211+
def _get_modules_to_save(self, pipe, has_denoiser=False):
212+
modules_to_save = {}
213+
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
214+
215+
if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
216+
modules_to_save["text_encoder"] = pipe.text_encoder
217+
218+
if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
219+
modules_to_save["text_encoder_2"] = pipe.text_encoder_2
220+
221+
if has_denoiser:
222+
if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
223+
modules_to_save["unet"] = pipe.unet
224+
225+
if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
226+
modules_to_save["transformer"] = pipe.transformer
227+
228+
return modules_to_save
229+
204230
def test_simple_inference(self):
205231
"""
206232
Tests a simple inference and makes sure it works as expected
@@ -420,45 +446,21 @@ def test_simple_inference_with_text_lora_save_load(self):
420446
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
421447

422448
with tempfile.TemporaryDirectory() as tmpdirname:
423-
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
424-
if self.has_two_text_encoders or self.has_three_text_encoders:
425-
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
426-
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
427-
428-
self.pipeline_class.save_lora_weights(
429-
save_directory=tmpdirname,
430-
text_encoder_lora_layers=text_encoder_state_dict,
431-
text_encoder_2_lora_layers=text_encoder_2_state_dict,
432-
safe_serialization=False,
433-
)
434-
else:
435-
self.pipeline_class.save_lora_weights(
436-
save_directory=tmpdirname,
437-
text_encoder_lora_layers=text_encoder_state_dict,
438-
safe_serialization=False,
439-
)
449+
modules_to_save = self._get_modules_to_save(pipe)
450+
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
440451

441-
if self.has_two_text_encoders:
442-
if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
443-
self.pipeline_class.save_lora_weights(
444-
save_directory=tmpdirname,
445-
text_encoder_lora_layers=text_encoder_state_dict,
446-
safe_serialization=False,
447-
)
452+
self.pipeline_class.save_lora_weights(
453+
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
454+
)
448455

449456
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
450457
pipe.unload_lora_weights()
451-
452458
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
453459

454-
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
455-
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
460+
for module_name, module in modules_to_save.items():
461+
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
456462

457-
if self.has_two_text_encoders or self.has_three_text_encoders:
458-
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
459-
self.assertTrue(
460-
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
461-
)
463+
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
462464

463465
self.assertTrue(
464466
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
@@ -614,54 +616,20 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
614616
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
615617

616618
with tempfile.TemporaryDirectory() as tmpdirname:
617-
text_encoder_state_dict = (
618-
get_peft_model_state_dict(pipe.text_encoder)
619-
if "text_encoder" in self.pipeline_class._lora_loadable_modules
620-
else None
619+
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
620+
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
621+
self.pipeline_class.save_lora_weights(
622+
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
621623
)
622624

623-
denoiser_state_dict = get_peft_model_state_dict(denoiser)
624-
625-
saving_kwargs = {
626-
"save_directory": tmpdirname,
627-
"safe_serialization": False,
628-
}
629-
630-
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
631-
saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict})
632-
633-
if self.unet_kwargs is not None:
634-
saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
635-
else:
636-
saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})
637-
638-
if self.has_two_text_encoders or self.has_three_text_encoders:
639-
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
640-
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
641-
saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})
642-
643-
self.pipeline_class.save_lora_weights(**saving_kwargs)
644-
645625
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
646626
pipe.unload_lora_weights()
647-
648627
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
649628

650-
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
651-
652-
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
653-
self.assertTrue(
654-
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
655-
)
656-
657-
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
658-
659-
if self.has_two_text_encoders or self.has_three_text_encoders:
660-
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
661-
self.assertTrue(
662-
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
663-
)
629+
for module_name, module in modules_to_save.items():
630+
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
664631

632+
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
665633
self.assertTrue(
666634
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
667635
"Loading from saved checkpoints should give same results.",

0 commit comments

Comments (0)