
Commit cc7d88f

ParagEkbote, AMohamedAakhil, and sayakpaul authored
Move IP Adapter Scripts to research project (#9960)
* Move files to research-projects.
* docs: add IP Adapter training instructions
* Delete venv
* Update examples/ip_adapter/tutorial_train_sdxl.py
  Co-authored-by: Sayak Paul <[email protected]>
* Cherry-picked commits and re-moved files to research_projects.
* make style.
* Update toctree and delete ip_adapter.
* Nit Fix
* Fix nit.
* Fix nit.
* Create training script for single GPU and set model format to .safetensors
* Add sample inference script and restore _toctree
* Restore toctree.yaml
* fix spacing.
* Update toctree.yaml

---------

Co-authored-by: AMohamedAakhil <[email protected]>
Co-authored-by: BootesVoid <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent ea40933 commit cc7d88f

6 files changed: +2032 additions, 0 deletions
Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
# IP Adapter Training Example

[IP Adapter](https://arxiv.org/abs/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling the cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
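To give a sense of how a trained IP Adapter is used, the sketch below applies the publicly released `h94/IP-Adapter` weights to a Stable Diffusion pipeline through diffusers' `load_ip_adapter` API. The base model, adapter weights, scale, and reference image are illustrative choices and are not produced by this training example.

```python
# Illustrative sketch: conditioning Stable Diffusion on an image prompt with a
# pretrained IP Adapter. The repo ids, scale, and image path are example choices.
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Attach the image projection model and the decoupled image cross-attention layers
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.set_ip_adapter_scale(0.6)  # how strongly the image prompt steers generation

image_prompt = load_image("path/or/url/to/reference_image.png")
result = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=image_prompt,
    num_inference_steps=50,
).images[0]
result.save("ip_adapter_result.png")
```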
## Training locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g. a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```
### Accelerate Launch Command Documentation

#### Description:

The `accelerate launch` command runs the training script `tutorial_train_ip-adapter.py` with the specified parameters and mixed precision (fp16) training. The example below is a single-GPU launch; a multi-GPU variant follows in the next section.

#### Usage Example:

```bash
accelerate launch --mixed_precision "fp16" \
  tutorial_train_ip-adapter.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --image_encoder_path="{image_encoder_path}" \
  --data_json_file="{data.json}" \
  --data_root_path="{image_path}" \
  --mixed_precision="fp16" \
  --resolution=512 \
  --train_batch_size=8 \
  --dataloader_num_workers=4 \
  --learning_rate=1e-04 \
  --weight_decay=0.01 \
  --output_dir="{output_dir}" \
  --save_steps=10000
```
### Multi-GPU Script

```bash
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
  tutorial_train_ip-adapter.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --image_encoder_path="{image_encoder_path}" \
  --data_json_file="{data.json}" \
  --data_root_path="{image_path}" \
  --mixed_precision="fp16" \
  --resolution=512 \
  --train_batch_size=8 \
  --dataloader_num_workers=4 \
  --learning_rate=1e-04 \
  --weight_decay=0.01 \
  --output_dir="{output_dir}" \
  --save_steps=10000
```
#### Parameters:

- `--num_processes`: Number of processes to launch for distributed training (8 processes in this example).
- `--multi_gpu`: Flag indicating that multiple GPUs should be used for training.
- `--mixed_precision "fp16"`: Enables mixed precision training with 16-bit floating-point precision.
- `tutorial_train_ip-adapter.py`: Name of the training script to be executed.
- `--pretrained_model_name_or_path`: Path or Hub identifier of the pretrained base model.
- `--image_encoder_path`: Path to the CLIP image encoder.
- `--data_json_file`: Path to the training data in JSON format (see the sketch after this list for an assumed layout).
- `--data_root_path`: Root path where the training images are located.
- `--resolution`: Resolution of the input images (512x512 in this example).
- `--train_batch_size`: Batch size for training (8 in this example).
- `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example).
- `--learning_rate`: Learning rate for training (1e-04 in this example).
- `--weight_decay`: Weight decay for regularization (0.01 in this example).
- `--output_dir`: Directory in which to save model checkpoints and predictions.
- `--save_steps`: How often to save a checkpoint during training (every 10,000 steps in this example).
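The exact schema of `--data_json_file` is defined by `tutorial_train_ip-adapter.py`, so treat the layout below as an assumption for illustration only: a list of records pairing each image file (relative to `--data_root_path`) with its caption.

```python
# Hypothetical data.json layout; check tutorial_train_ip-adapter.py for the
# exact fields it expects before relying on this format.
import json

records = [
    {"image_file": "images/0001.jpg", "text": "a photo of a corgi wearing sunglasses"},
    {"image_file": "images/0002.jpg", "text": "an oil painting of a lighthouse at dusk"},
]

with open("data.json", "w") as f:
    json.dump(records, f, indent=2)

# Sanity check: every record should reference an image that exists under --data_root_path.
with open("data.json") as f:
    loaded = json.load(f)
print(f"{len(loaded)} training records, first caption: {loaded[0]['text']!r}")
```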
### Inference

#### Description:

The inference code below loads a trained model checkpoint and extracts the components related to the image projection model and the IP (image prompt) adapter. These components are then saved into separate `.safetensors` files for later use at inference time.

#### Usage Example:
```python
from safetensors.torch import load_file, save_file

# Load the trained model checkpoint in safetensors format
ckpt = "checkpoint-50000/pytorch_model.safetensors"
sd = load_file(ckpt)  # Using the safetensors load function

# Extract the image projection and IP adapter components
image_proj_sd = {}
ip_sd = {}

for k in sd:
    if k.startswith("unet"):
        pass  # Skip UNet-related keys
    elif k.startswith("image_proj_model"):
        image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
    elif k.startswith("adapter_modules"):
        ip_sd[k.replace("adapter_modules.", "")] = sd[k]

# Save the components into separate safetensors files
save_file(image_proj_sd, "image_proj.safetensors")
save_file(ip_sd, "ip_adapter.safetensors")
```
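As a quick check on what was exported, the short sketch below (assuming the two files above were written) reports how many tensors and parameters each file contains:

```python
# Sanity check on the extracted components; assumes the two files above exist.
from safetensors.torch import load_file

for name in ("image_proj.safetensors", "ip_adapter.safetensors"):
    state_dict = load_file(name)
    total_params = sum(t.numel() for t in state_dict.values())
    print(f"{name}: {len(state_dict)} tensors, {total_params:,} parameters")
    # Show a few keys and shapes to confirm the prefixes were stripped as expected
    for key in list(state_dict)[:3]:
        print(f"  {key}: {tuple(state_dict[key].shape)}")
```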
### Sample Inference Script using the CLIP Model

```python
import torch
from PIL import Image
from safetensors.torch import load_file
from transformers import CLIPProcessor, CLIPModel  # Using the Hugging Face CLIP model

# Paths to the extracted model components
image_proj_ckpt = "image_proj.safetensors"
ip_adapter_ckpt = "ip_adapter.safetensors"

# Load the saved weights
image_proj_sd = load_file(image_proj_ckpt)
ip_adapter_sd = load_file(ip_adapter_ckpt)

# Define placeholder model components.
# Note: the layer shapes must match the tensors saved during training;
# adjust these modules to mirror your actual checkpoint.
class ImageProjectionModel(torch.nn.Module):
    def __init__(self, input_dim=512, output_dim=512):  # CLIP ViT-B/32 image features are 512-dimensional
        super().__init__()
        self.model = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.model(x)


class IPAdapterModel(torch.nn.Module):
    def __init__(self, input_dim=512, output_dim=10):  # Example with 10 output classes
        super().__init__()
        self.model = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.model(x)


# Initialize models
image_proj_model = ImageProjectionModel()
ip_adapter_model = IPAdapterModel()

# Load weights into the models (shapes must match the saved tensors)
image_proj_model.load_state_dict(image_proj_sd)
ip_adapter_model.load_state_dict(ip_adapter_sd)

# Set models to evaluation mode
image_proj_model.eval()
ip_adapter_model.eval()


# Inference pipeline
def inference(image_tensor):
    """
    Run inference using the loaded models.

    Args:
        image_tensor: Preprocessed image features from the CLIP model

    Returns:
        Final inference results
    """
    with torch.no_grad():
        # Step 1: Project the image features
        image_proj = image_proj_model(image_tensor)

        # Step 2: Pass the projected features through the IP Adapter
        result = ip_adapter_model(image_proj)

        return result


# Use CLIP for image preprocessing
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Image file path
image_path = "path/to/image.jpg"

# Load and preprocess the image
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt")
image_features = clip_model.get_image_features(pixel_values=inputs["pixel_values"])

# Normalize the image features as per CLIP's recommendations
image_features = image_features / image_features.norm(dim=-1, keepdim=True)

# Run inference
output = inference(image_features)
print("Inference output:", output)
```
#### Parameters:

These names refer to the checkpoint-extraction snippet in the Inference section above:

- `ckpt`: Path to the trained model checkpoint file (in `.safetensors` format).
- `load_file(ckpt)`: Loads the checkpoint; the tensors are placed on the CPU by default.
- `image_proj_sd`: Dictionary holding the components related to the image projection model.
- `ip_sd`: Dictionary holding the components related to the IP adapter.
- `"unet"`, `"image_proj_model"`, `"adapter_modules"`: Key prefixes identifying the corresponding components of the trained model.
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
accelerate
torchvision
transformers>=4.25.1
ip_adapter
