[Flux LoRA] support parsing alpha from a flux lora state dict. #9236


Merged
merged 9 commits on Aug 22, 2024
44 changes: 37 additions & 7 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1495,10 +1495,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):

@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
return_alphas: bool = False,
Collaborator:

is this argument there to prevent breaking? i.e., would lora_state_dict() be used outside our load_lora_weights() method?

Contributor:

yes

Member Author:

Yes. We use lora_state_dict() in the training scripts (an example of usage outside of load_lora_weights()).

**kwargs,
):
r"""
@@ -1583,7 +1583,26 @@ def lora_state_dict(
allow_pickle=allow_pickle,
)

return state_dict
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
network_alphas = {}
for k in keys:
if "alpha" in k:
alpha_value = state_dict.get(k)
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
alpha_value, float
):
network_alphas[k] = state_dict.pop(k)
else:
raise ValueError(
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
)

if return_alphas:
return state_dict, network_alphas
else:
return state_dict

def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1617,14 +1636,17 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)

is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

self.load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
@@ -1634,7 +1656,7 @@
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=None,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
Expand All @@ -1643,8 +1665,7 @@ def load_lora_weights(
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1653,6 +1674,10 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
transformer (`SD3Transformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
@@ -1684,7 +1709,12 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ad
if "lora_B" in key:
rank[key] = val.shape[1]

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
if network_alphas is not None and len(network_alphas) >= 1:
prefix = cls.transformer_name
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
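For context on what the new return_alphas path gives a caller: lora_state_dict() is a classmethod, so it can be used on its own, for example in training scripts, as noted in the review discussion above. Below is a minimal usage sketch; the repository id and file name are the ones referenced in this PR, but the snippet itself is illustrative and not code from the diff.

from diffusers import FluxPipeline

# lora_state_dict() can be called without instantiating a pipeline. With
# return_alphas=True it now also returns the parsed network alphas; the "*.alpha"
# entries are stripped out of the returned state dict.
state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA",
    weight_name="jon_snow.safetensors",
    return_alphas=True,
)

print(f"{len(network_alphas)} alpha entries found")
for key, alpha in list(network_alphas.items())[:3]:
    print(key, "->", float(alpha))

# load_lora_weights() forwards these alphas to load_lora_into_transformer(), where
# get_peft_kwargs() converts them into the lora_alpha used by PEFT
# (effective LoRA scaling = lora_alpha / rank).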
59 changes: 57 additions & 2 deletions tests/lora/test_lora_layers_flux.py
@@ -12,19 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest

import numpy as np
import safetensors.torch
import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device


if is_peft_available():
from peft.utils import get_peft_model_state_dict

sys.path.append(".")

from utils import PeftLoraLoaderMixinTests # noqa: E402
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402


@require_peft_backend
@@ -90,3 +97,51 @@ def get_dummy_inputs(self, with_generator=True):
pipeline_inputs.update({"generator": generator})

return noise, input_ids, pipeline_inputs

def test_with_alpha_in_state_dict(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)

pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)

self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

# modify the state dict to have alpha values following
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
state_dict_with_alpha = safetensors.torch.load_file(
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
Member:

Would it also be possible to update the scales in this model according to the alphas below, create an image, and then assert below that this image is identical to images_lora_with_alpha?

Member Author:

I don't fully understand it. Would it be possible to provide a schematic of what you mean?

Member:

So IIUC, you dump the state dict, then edit the alphas in the state dict, then load the state dict with the edited alphas below and check that the image has changed. My suggestion is that if the scalings of the LoRAs in the transformer were edited here, using the same random changes as in the state dict, we could create an image with these altered alphas. Then below we could assert that this image is identical to images_lora_with_alpha.

My reasoning for this is that a check that the image is unequal is always a bit weaker than an equality check, so the test would be stronger. However, I see how this is extra effort, so I totally understand if it's not worth it.

Member Author:

Sure, happy to give this a try but I still don't understand how we can update the scaling. How are you envisioning that?

Member:

Hmm, yes, I checked and it seems there is no easy way. It is possible to pass a scale argument to forward, but this would always be the same value. Up to you if you think this would be worth adding to the test or not.

Member Author:

We could probably configure the alpha value in the LoraConfig and then generate images with that config but I think it's okay to leave that for now since we're testing it here: #9143. When either of the two PRs is merged, we can revisit. What say?

Member:

Sounds good

for k, v in state_dict_with_alpha.items():
# only do for `transformer` and for the k projections -- should be enough to test.
if "transformer" in k and "to_k" in k and "lora_A" in k:
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
state_dict_with_alpha.update(alpha_dict)

images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")

pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict_with_alpha)
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images

self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
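For reference on what the alpha edit in this test actually changes: PEFT applies a LoRA update with a multiplier of lora_alpha / r, and the kohya-style ".alpha" entries that this PR parses feed exactly that multiplier; this is the scaling the review discussion above is about. A small standalone sketch of the arithmetic (not code from the PR or from PEFT):

def lora_scaling(alpha: float, rank: int) -> float:
    # Effective multiplier applied to the LoRA update: lora_alpha / r.
    return alpha / rank

# With the common convention alpha == rank, the update is applied at full strength.
assert lora_scaling(alpha=4, rank=4) == 1.0

# The test injects random alphas in [10, 100) for the to_k projections, so those
# layers end up rescaled and the generated image should differ from the alpha-free
# baseline, which is what the final assertFalse above checks.
print(lora_scaling(alpha=64, rank=4))  # prints 16.0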