Commit 4690db2

raulc0399, yiyixuxu, sayakpaul, and lawrence-cj committed
adds the pipeline for pixart alpha controlnet (#8857)
* add the controlnet pipeline for pixart alpha

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: junsongc <[email protected]>
1 parent 5905401 commit 4690db2

8 files changed, 2778 insertions(+), 0 deletions(-)
examples/community/README.md

Lines changed: 92 additions & 0 deletions
@@ -73,6 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
| PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |

@@ -4445,3 +4446,94 @@ grid_image.save(grid_dir + "sample.png")
`pag_scale` : guidance scale of PAG (ex: 5.0)

`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])

# PIXART-α Controlnet pipeline

[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)

This is the implementation of the ControlNet model and the pipeline for the PixArt-alpha model, adapted to use Hugging Face Diffusers.

## Example Usage

This example uses the PixArt HED ControlNet model, converted from the ControlNet model trained by the authors of the paper.

```py
import sys
import os
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
from diffusers.utils import load_image

from diffusers.image_processor import PixArtImageProcessor

from controlnet_aux import HEDdetector

# make the parent folder importable so the `pixart` package can be found
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel

controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"

weight_dtype = torch.float16
image_size = 1024

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)

# load controlnet
controlnet = PixArtControlNetAdapterModel.from_pretrained(
    controlnet_repo_id,
    torch_dtype=weight_dtype,
    use_safetensors=True,
).to(device)

# load the base PixArt-alpha pipeline and attach the controlnet
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    controlnet=controlnet,
    torch_dtype=weight_dtype,
    use_safetensors=True,
).to(device)

images_path = "images"
control_image_file = "0_7.jpg"

prompt = "battleship in space, galaxy in background"

control_image_name = control_image_file.split('.')[0]

control_image = load_image(f"{images_path}/{control_image_file}")
print(control_image.size)
# PIL reports size as (width, height)
width, height = control_image.size

hed = HEDdetector.from_pretrained("lllyasviel/Annotators")

# convert to RGB and center-crop the control image to the target resolution
condition_transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB')),
    T.CenterCrop([image_size, image_size]),
])

control_image = condition_transform(control_image)
# extract the HED edge map that is used as the conditioning image
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)

hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")

# run pipeline
with torch.no_grad():
    out = pipe(
        prompt=prompt,
        image=hed_edge,
        num_inference_steps=14,
        guidance_scale=4.5,
        height=image_size,
        width=image_size,
    )

out.images[0].save(f"{images_path}/{control_image_name}_output.jpg")

```
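
The preprocessing in the example depends on two details that are easy to get wrong: PIL reports `Image.size` as `(width, height)`, and the control image is center-cropped to `image_size` before HED detection. The crop arithmetic can be sketched without any dependencies as follows. The helper name `center_crop_box` is hypothetical (it is not part of diffusers or torchvision), and the sketch assumes the image is at least as large as the crop in both dimensions, so it does not attempt to reproduce how `T.CenterCrop` handles smaller inputs.

```py
def center_crop_box(size, crop):
    """Return (left, top, right, bottom) of a centered square crop.

    `size` follows the PIL convention (width, height). Illustrative helper
    only: it assumes the image is at least `crop` pixels wide and tall,
    so no padding is ever needed.
    """
    width, height = size  # PIL order: width first, then height
    if width < crop or height < crop:
        raise ValueError("image is smaller than the requested crop")
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


# a 1536x1024 (width x height) image cropped to the 1024x1024 target
box = center_crop_box((1536, 1024), 1024)
print(box)  # (256, 0, 1280, 1024)
```

The returned tuple uses the `(left, upper, right, lower)` layout that `PIL.Image.Image.crop` accepts. A control image smaller than `image_size` would probably need to be resized first.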

In the folder examples/pixart there is also a script that can be used to train new models.
See the script `train_controlnet_hf_diffusers.sh` for how to start the training.
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
images/
output/
