Skip to content

Commit e1ea942

Browse files
vladmandicyiyixuxu
authored andcommitted
Several fixes to Flux ControlNet pipelines (#9472)
* fix flux controlnet pipelines --------- Co-authored-by: yiyixuxu <[email protected]>
1 parent fd15a9d commit e1ea942

File tree

4 files changed

+25
-13
lines changed

4 files changed

+25
-13
lines changed

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,14 @@
2929
StableDiffusionXLControlNetPipeline,
3030
)
3131
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
32-
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
32+
from .flux import (
33+
FluxControlNetImg2ImgPipeline,
34+
FluxControlNetInpaintPipeline,
35+
FluxControlNetPipeline,
36+
FluxImg2ImgPipeline,
37+
FluxInpaintPipeline,
38+
FluxPipeline,
39+
)
3340
from .hunyuandit import HunyuanDiTPipeline
3441
from .kandinsky import (
3542
KandinskyCombinedPipeline,
@@ -128,6 +135,7 @@
128135
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
129136
("lcm", LatentConsistencyModelImg2ImgPipeline),
130137
("flux", FluxImg2ImgPipeline),
138+
("flux-controlnet", FluxControlNetImg2ImgPipeline),
131139
]
132140
)
133141

@@ -143,6 +151,7 @@
143151
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
144152
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
145153
("flux", FluxInpaintPipeline),
154+
("flux-controlnet", FluxControlNetInpaintPipeline),
146155
]
147156
)
148157

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def __call__(
729729
batch_size=batch_size * num_images_per_prompt,
730730
num_images_per_prompt=num_images_per_prompt,
731731
device=device,
732-
dtype=dtype,
732+
dtype=self.vae.dtype,
733733
)
734734
height, width = control_image.shape[-2:]
735735

@@ -763,7 +763,7 @@ def __call__(
763763
batch_size=batch_size * num_images_per_prompt,
764764
num_images_per_prompt=num_images_per_prompt,
765765
device=device,
766-
dtype=dtype,
766+
dtype=self.vae.dtype,
767767
)
768768
height, width = control_image_.shape[-2:]
769769

@@ -840,12 +840,10 @@ def __call__(
840840
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
841841
timestep = t.expand(latents.shape[0]).to(latents.dtype)
842842

843-
# handle guidance
844-
if self.transformer.config.guidance_embeds:
845-
guidance = torch.tensor([guidance_scale], device=device)
846-
guidance = guidance.expand(latents.shape[0])
847-
else:
848-
guidance = None
843+
guidance = (
844+
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
845+
)
846+
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
849847

850848
# controlnet
851849
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
@@ -863,6 +861,11 @@ def __call__(
863861
return_dict=False,
864862
)
865863

864+
guidance = (
865+
torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
866+
)
867+
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
868+
866869
noise_pred = self.transformer(
867870
hidden_states=latents,
868871
timestep=timestep / 1000,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ def __call__(
767767
batch_size=batch_size * num_images_per_prompt,
768768
num_images_per_prompt=num_images_per_prompt,
769769
device=device,
770-
dtype=dtype,
770+
dtype=self.vae.dtype,
771771
)
772772
height, width = control_image.shape[-2:]
773773

@@ -798,7 +798,7 @@ def __call__(
798798
batch_size=batch_size * num_images_per_prompt,
799799
num_images_per_prompt=num_images_per_prompt,
800800
device=device,
801-
dtype=dtype,
801+
dtype=self.vae.dtype,
802802
)
803803
height, width = control_image_.shape[-2:]
804804

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ def __call__(
899899
batch_size=batch_size * num_images_per_prompt,
900900
num_images_per_prompt=num_images_per_prompt,
901901
device=device,
902-
dtype=dtype,
902+
dtype=self.vae.dtype,
903903
)
904904
height, width = control_image.shape[-2:]
905905

@@ -933,7 +933,7 @@ def __call__(
933933
batch_size=batch_size * num_images_per_prompt,
934934
num_images_per_prompt=num_images_per_prompt,
935935
device=device,
936-
dtype=dtype,
936+
dtype=self.vae.dtype,
937937
)
938938
height, width = control_image_.shape[-2:]
939939

0 commit comments

Comments
 (0)