[Flux LoRA] support parsing alpha from a flux lora state dict. #9236
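For background: in PEFT-style LoRA checkpoints, an `alpha` entry carries the numerator of the LoRA scale (scale = alpha / rank), so a checkpoint that ships `.alpha` keys should steer the model differently once the loader actually parses them, which is what the new test below asserts. The snippet is a minimal sketch of such a state dict following the `f"{k}.alpha"` key convention used in the test; the module path, rank, and feature sizes are illustrative and not taken from this PR.

```python
# Minimal sketch (illustrative names/shapes) of a LoRA state dict carrying an
# `.alpha` entry, mirroring the key convention used in the test added below.
import torch

rank, features = 4, 64
key = "transformer.single_transformer_blocks.0.attn.to_k"
state_dict = {
    f"{key}.lora_A.weight": torch.randn(rank, features),
    f"{key}.lora_B.weight": torch.randn(features, rank),
    # the scalar the Flux LoRA loader is taught to pick up
    f"{key}.lora_A.weight.alpha": torch.tensor(32.0),
}

# Under PEFT's convention the effective LoRA scale for this module is alpha / rank,
# so a parsed alpha of 32 with rank 4 yields a scale of 8 instead of the config default.
effective_scale = state_dict[f"{key}.lora_A.weight.alpha"].item() / rank
print(effective_scale)  # 8.0
```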
@@ -12,19 +12,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
 import sys
 import tempfile
 import unittest

 import numpy as np
 import safetensors.torch
 import torch
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
+
+
+if is_peft_available():
+    from peft.utils import get_peft_model_state_dict
+
 sys.path.append(".")

-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402


 @require_peft_backend
@@ -90,3 +97,51 @@ def get_dummy_inputs(self, with_generator=True):
             pipeline_inputs.update({"generator": generator})

         return noise, input_ids, pipeline_inputs

+    def test_with_alpha_in_state_dict(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        pipe.transformer.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            # modify the state dict to have alpha values following
+            # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
+            state_dict_with_alpha = safetensors.torch.load_file(
+                os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
+            )
+            alpha_dict = {}
Review thread on the `alpha_dict = {}` line (a sketch of the idea discussed here follows the diff below):

Reviewer: Would it also be possible to update the scales in this model according to the alphas below and create an image, then below assert that this image is identical to […]?

Author: I don't fully understand it. Would it be possible to provide a schematic of what you mean?

Reviewer: So IIUC, you dump the state dict, then edit the alphas in the state dict, then load the state dict with the edited alphas below and check that the image has changed. My suggestion is that if the scalings of the LoRAs of the transformer were edited here, using the same random changes as in the state dict, we could create an image with these altered alphas. Then below we can assert that this image should be identical to […]. My reasoning is that a check that the image is unequal is always a bit weaker than an equals check, so the test would be stronger. However, I see how this is extra effort, so I totally understand if it's not worth it.

Author: Sure, happy to give this a try, but I still don't understand how we can update the scaling. How are you envisioning that?

Reviewer: Hmm, yes, I checked and it seems there is no easy way. It is possible to pass a […]

Author: We could probably configure the […]

Reviewer: Sounds good.
+            for k, v in state_dict_with_alpha.items():
+                # only do for `transformer` and for the k projections -- should be enough to test.
+                if "transformer" in k and "to_k" in k and "lora_A" in k:
+                    alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
+            state_dict_with_alpha.update(alpha_dict)
+
+        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(state_dict_with_alpha)
+        images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+        self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
Review thread on the loader changes:

Reviewer: Is this argument there to prevent breaking? I.e., `load_state_dict()` would be used outside our `load_lora_weights()` method?

Author: Yes.

Author: Yes. We use `lora_state_dict()` in the training (example usage outside of `load_lora_weights()`).
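Regarding that last reply, here is a hedged example of calling `lora_state_dict()` on its own, outside of `load_lora_weights()`, roughly the way training or inspection scripts use it. The checkpoint path is a placeholder, and the snippet assumes the default invocation keeps returning a plain state dict so that existing callers do not break.

```python
# Illustrative standalone use of `lora_state_dict()` outside `load_lora_weights()`,
# e.g. to peek at a LoRA checkpoint's keys before loading it into a pipeline.
from diffusers import FluxPipeline

state_dict = FluxPipeline.lora_state_dict("path/to/pytorch_lora_weights.safetensors")  # placeholder path
print(f"{len(state_dict)} LoRA tensors")
print(sorted(state_dict)[:3])  # a few key names, e.g. transformer.*.lora_A.weight
```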