Skip to content

Commit 9fb4880

Browse files
committed
fix: missing AutoencoderKL lora adapter
1 parent 0d1d267 commit 9fb4880

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
)
3232
from ..modeling_outputs import AutoencoderKLOutput
3333
from ..modeling_utils import ModelMixin
34+
from ...loaders import PeftAdapterMixin
3435
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
3536

3637

37-
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
38+
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
3839
r"""
3940
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
4041

tests/models/autoencoders/test_models_vae.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from diffusers.utils.torch_utils import randn_tensor
5050

5151
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
52-
52+
from peft import LoraConfig
5353

5454
enable_full_determinism()
5555

@@ -299,7 +299,38 @@ def test_output_pretrained(self):
299299

300300
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
301301

302+
def test_lora_adapter(self):
303+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
304+
vae = self.model_class(**init_dict)
305+
306+
target_modules_vae = [
307+
"conv1",
308+
"conv2",
309+
"conv_in",
310+
"conv_shortcut",
311+
"conv",
312+
"conv_out",
313+
"skip_conv_1",
314+
"skip_conv_2",
315+
"skip_conv_3",
316+
"skip_conv_4",
317+
"to_k",
318+
"to_q",
319+
"to_v",
320+
"to_out.0",
321+
]
322+
vae_lora_config = LoraConfig(
323+
r=16,
324+
init_lora_weights="gaussian",
325+
target_modules=target_modules_vae,
326+
)
327+
328+
vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
329+
active_lora = vae.active_adapters()
330+
self.assertTrue(len(active_lora) == 1)
331+
self.assertTrue(active_lora[0] == "vae_lora")
302332

333+
303334
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
304335
model_class = AsymmetricAutoencoderKL
305336
main_input_name = "sample"

0 commit comments

Comments
 (0)