Commit 7025526

fix: missing AutoencoderKL lora adapter
1 parent: 0d1d267

2 files changed: +39 −1 lines

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
@@ -34,7 +35,7 @@
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
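With PeftAdapterMixin in the class hierarchy, a LoRA adapter can now be attached to an AutoencoderKL instance directly. A minimal sketch of the usage this change enables (the checkpoint name and the target modules below are illustrative assumptions, not taken from this commit):

from diffusers import AutoencoderKL
from peft import LoraConfig

# Any AutoencoderKL checkpoint works; "stabilityai/sd-vae-ft-mse" is just an example.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# LoRA on the attention projections only; the new test below targets a wider set.
lora_config = LoraConfig(
    r=16,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)

# add_adapter() and active_adapters() are provided by PeftAdapterMixin.
vae.add_adapter(lora_config, adapter_name="vae_lora")
assert vae.active_adapters() == ["vae_lora"]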

tests/models/autoencoders/test_models_vae.py

Lines changed: 37 additions & 0 deletions
@@ -21,6 +21,10 @@
 from datasets import load_dataset
 from parameterized import parameterized
 
+
+if is_peft_available():
+    from peft import LoraConfig
+
 from diffusers import (
     AsymmetricAutoencoderKL,
     AutoencoderKL,
@@ -36,6 +40,7 @@
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
+    is_peft_available,
     load_hf_numpy,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
@@ -299,6 +304,38 @@ def test_output_pretrained(self):
 
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
 
+    @require_peft_backend
+    def test_lora_adapter(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        vae = self.model_class(**init_dict)
+
+        target_modules_vae = [
+            "conv1",
+            "conv2",
+            "conv_in",
+            "conv_shortcut",
+            "conv",
+            "conv_out",
+            "skip_conv_1",
+            "skip_conv_2",
+            "skip_conv_3",
+            "skip_conv_4",
+            "to_k",
+            "to_q",
+            "to_v",
+            "to_out.0",
+        ]
+        vae_lora_config = LoraConfig(
+            r=16,
+            init_lora_weights="gaussian",
+            target_modules=target_modules_vae,
+        )
+
+        vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
+        active_lora = vae.active_adapters()
+        self.assertTrue(len(active_lora) == 1)
+        self.assertTrue(active_lora[0] == "vae_lora")
+
 
 class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AsymmetricAutoencoderKL
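The test verifies only that the adapter registers and becomes active. For reference, a hedged sketch of the other adapter controls that PeftAdapterMixin provides, continuing from the vae above (none of these calls are exercised by this commit):

vae.disable_adapters()           # forward passes bypass the LoRA layers
vae.enable_adapters()            # re-enable "vae_lora"
vae.set_adapter("vae_lora")      # select it as the active adapter explicitly
vae.delete_adapters("vae_lora")  # detach and remove the adapter entirely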
