Skip to content

Commit e92bbf4

Browse files
sayakpaulDN6
andcommitted
[Core] introduce controlnet module (#8768)
* move vae flax module. * controlnet module. * prepare for PR. * revert a commit * gracefully deprecate controlnet deps. * fix * fix doc path * fix-copies * fix path * style * style * conflicts * fix * fix-copies * sparsectrl. * updates * fix * updates * updates * updates * fix --------- Co-authored-by: Dhruv Nair <[email protected]>
1 parent 221d6db commit e92bbf4

26 files changed

+2970
-2752
lines changed

docs/source/en/api/models/controlnet.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
3939

4040
## ControlNetOutput
4141

42-
[[autodoc]] models.controlnet.ControlNetOutput
42+
[[autodoc]] models.controlnets.controlnet.ControlNetOutput
4343

4444
## FlaxControlNetModel
4545

4646
[[autodoc]] FlaxControlNetModel
4747

4848
## FlaxControlNetOutput
4949

50-
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
50+
[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput

docs/source/en/api/models/controlnet_sd3.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ pipe = StableDiffusion3ControlNetPipeline.from_pretrained("stabilityai/stable-di
3838

3939
## SD3ControlNetOutput
4040

41-
[[autodoc]] models.controlnet_sd3.SD3ControlNetOutput
41+
[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput
4242

examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,11 @@ def forward(
229229
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
230230
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
231231
return_dict (`bool`, defaults to `True`):
232-
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
232+
Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.
233233
234234
Returns:
235-
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
236-
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
235+
[`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:
236+
If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
237237
returned where the first element is the sample tensor.
238238
"""
239239
# check channel order

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@
487487

488488

489489
else:
490-
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
490+
_import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
491491
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
492492
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
493493
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
@@ -914,7 +914,7 @@
914914
except OptionalDependencyNotAvailable:
915915
from .utils.dummy_flax_objects import * # noqa F403
916916
else:
917-
from .models.controlnet_flax import FlaxControlNetModel
917+
from .models.controlnets.controlnet_flax import FlaxControlNetModel
918918
from .models.modeling_flax_utils import FlaxModelMixin
919919
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
920920
from .models.vae_flax import FlaxAutoencoderKL

src/diffusers/models/__init__.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,16 @@
3636
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
3737
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
3838
_import_structure["autoencoders.vq_model"] = ["VQModel"]
39-
_import_structure["controlnet"] = ["ControlNetModel"]
40-
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
41-
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
42-
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
43-
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
44-
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
39+
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
40+
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
41+
_import_structure["controlnets.controlnet_hunyuan"] = [
42+
"HunyuanDiT2DControlNetModel",
43+
"HunyuanDiT2DMultiControlNetModel",
44+
]
45+
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
46+
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
47+
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
48+
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
4549
_import_structure["embeddings"] = ["ImageProjection"]
4650
_import_structure["modeling_utils"] = ["ModelMixin"]
4751
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -74,7 +78,7 @@
7478
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
7579

7680
if is_flax_available():
77-
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
81+
_import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
7882
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
7983
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
8084

@@ -94,12 +98,19 @@
9498
ConsistencyDecoderVAE,
9599
VQModel,
96100
)
97-
from .controlnet import ControlNetModel
98-
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
99-
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
100-
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
101-
from .controlnet_sparsectrl import SparseControlNetModel
102-
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
101+
from .controlnets import (
102+
ControlNetModel,
103+
ControlNetXSAdapter,
104+
FluxControlNetModel,
105+
FluxMultiControlNetModel,
106+
HunyuanDiT2DControlNetModel,
107+
HunyuanDiT2DMultiControlNetModel,
108+
MultiControlNetModel,
109+
SD3ControlNetModel,
110+
SD3MultiControlNetModel,
111+
SparseControlNetModel,
112+
UNetControlNetXSModel,
113+
)
103114
from .embeddings import ImageProjection
104115
from .modeling_utils import ModelMixin
105116
from .transformers import (
@@ -137,7 +148,7 @@
137148
)
138149

139150
if is_flax_available():
140-
from .controlnet_flax import FlaxControlNetModel
151+
from .controlnets import FlaxControlNetModel
141152
from .unets import FlaxUNet2DConditionModel
142153
from .vae_flax import FlaxAutoencoderKL
143154

0 commit comments

Comments
 (0)