Pure Rust implementation for training modern diffusion models with GPU acceleration.
Image Models:
- SDXL - Stable Diffusion XL 1.0
- SD 3.5 - Stable Diffusion 3.5 (Medium/Large/Large-Turbo)
- Flux - Black Forest Labs Flux (Dev/Schnell)
- Flex - Next-gen architecture
- OmniGen 2 - Multi-modal generation
- HiDream - High-resolution synthesis
- Chroma - Advanced color model
- Sana - Efficient transformer
- Kolors - Bilingual diffusion model
Video Models:
- Wan Vace 2.1 - Video generation
- LTX - Long-form video synthesis
- Hunyuan - Multi-modal video model
- ✅ LoRA Training: Low-rank adaptation for all supported models (sketched in the example after this list)
- ✅ Var-based Training: Direct gradient tracking without VarBuilder limitations
- ✅ GPU-Only: Industry-standard GPU requirement (no CPU fallback)
- ✅ ComfyUI Compatible: Saves LoRA weights in ComfyUI format
- ✅ Memory Optimized: Designed for 24GB VRAM with gradient checkpointing
- ✅ Integrated Sampling: Generate samples during training to monitor progress
- ✅ 8-bit Adam: Memory-efficient optimizer
- ✅ Mixed Precision: BF16/FP16 training support
- 🚧 Full Finetune: Complete model fine-tuning (not just LoRA)
- 🚧 DoRA: Weight-Decomposed Low-Rank Adaptation
- 🚧 LoKr: Low-rank Kronecker product adaptation
- 🚧 Multi-GPU: Distributed training support
- 🚧 FSDP: Fully Sharded Data Parallel training
- 🚧 Flash Attention 3: Latest attention optimizations
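LoRA trains a low-rank update on top of each frozen weight: the layer computes `y = W x + (alpha/rank) * B(A x)`, where only `A` and `B` receive gradients. A minimal Candle sketch of that forward pass (illustrative only; `LoraLinear` and its layout are hypothetical, not EriDiffusion's internal module, and F32 is assumed throughout):

```rust
use candle_core::{DType, Result, Tensor, Var};

/// Hypothetical LoRA wrapper: a frozen base weight plus trainable A/B factors.
struct LoraLinear {
    w: Tensor,  // frozen base weight, shape (out, in)
    a: Var,     // trainable down-projection, shape (rank, in)
    b: Var,     // trainable up-projection, shape (out, rank)
    scale: f64, // alpha / rank
}

impl LoraLinear {
    fn new(w: Tensor, rank: usize, alpha: f64) -> Result<Self> {
        let (out_dim, in_dim) = w.dims2()?;
        let dev = w.device();
        // A starts random, B starts at zero, so step 0 reproduces the base model.
        let a = Var::from_tensor(&Tensor::randn(0f32, 0.02, (rank, in_dim), dev)?)?;
        let b = Var::from_tensor(&Tensor::zeros((out_dim, rank), DType::F32, dev)?)?;
        Ok(Self { w, a, b, scale: alpha / rank as f64 })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // y = x W^T + scale * ((x A^T) B^T)
        let base = x.matmul(&self.w.t()?)?;
        let delta = x
            .matmul(&self.a.as_tensor().t()?)?
            .matmul(&self.b.as_tensor().t()?)?;
        base + (delta * self.scale)?
    }
}
```

Because `B` is zero-initialized, the adapter contributes nothing at step 0 and training departs smoothly from the base model.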
- CUDA-capable GPU (required - no CPU training support)
- 24GB+ VRAM recommended
- CUDA 11.0 or higher
- Rust 1.70+
- Trainable-Candle fork (required for training support)
This project requires the Trainable-Candle fork from https://github.com/CodeAlexx/Trainable-Candle which provides:
- GPU-accelerated LoRA backward pass with cuBLAS
- Direct Var creation for training (bypasses VarBuilder limitations)
- Training-enabled Candle without the inference-only restrictions
- Clone both repositories:
# Clone Trainable-Candle fork (required)
git clone https://github.com/CodeAlexx/Trainable-Candle.git
# Clone EriDiffusion
git clone https://github.com/CodeAlexx/EriDiffusion.git
cd EriDiffusion
- Update Cargo.toml to point to your local Trainable-Candle (the `../` paths below assume the two repositories were cloned side by side):
[dependencies]
candle-core = { path = "../Trainable-Candle/candle-core", features = ["cuda", "cuda-backward"] }
candle-nn = { path = "../Trainable-Candle/candle-nn" }
candle-transformers = { path = "../Trainable-Candle/candle-transformers" }
- Build the project:
cargo build --release --features cuda-backward
- The executable will be at `target/release/trainer`. Copy it to your PATH or project root:
# Option 1: Copy to project root
cp target/release/trainer .
# Option 2: Install to system (requires sudo)
sudo cp target/release/trainer /usr/local/bin/
EriDiffusion uses a single `trainer` binary that automatically detects the model type from your YAML configuration:
# After building, run from project root:
./trainer config/sdxl_lora_24gb_optimized.yaml
./trainer config/sd35_lora_training.yaml
./trainer config/flux_lora_24gb.yaml
# Or with full path:
trainer /path/to/config/sdxl_lora_24gb_optimized.yaml
The trainer reads the model architecture from the YAML and automatically routes to the correct training pipeline.
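A hypothetical sketch of how such routing can look (only the `name_or_path` and `is_sdxl` keys appear in the config example below; the `arch` key and pipeline names here are illustrative assumptions):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Config {
    model: ModelSection,
}

#[derive(Deserialize)]
struct ModelSection {
    name_or_path: String,
    #[serde(default)]
    is_sdxl: bool,
    #[serde(default)]
    arch: Option<String>, // hypothetical key for non-SDXL architectures
}

fn main() -> anyhow::Result<()> {
    let path = std::env::args().nth(1).expect("usage: trainer <config.yaml>");
    let cfg: Config = serde_yaml::from_str(&std::fs::read_to_string(path)?)?;
    println!("model: {}", cfg.model.name_or_path);
    // Dispatch to the matching training pipeline based on the declared model type.
    match cfg.model.arch.as_deref() {
        _ if cfg.model.is_sdxl => println!("-> SDXL LoRA pipeline"),
        Some("sd35") => println!("-> SD 3.5 pipeline"),
        Some("flux") => println!("-> Flux pipeline"),
        other => anyhow::bail!("unknown architecture: {:?}", other),
    }
    Ok(())
}
```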
Each model has its own config file with model-specific settings:
- Model paths (must be local .safetensors files)
- Dataset location
- Training parameters
- LoRA rank and alpha
- Sampling settings

For example, a trimmed SDXL LoRA config:
model:
  name_or_path: "/path/to/sdxl_model.safetensors"  # must be a local .safetensors file
  is_sdxl: true
network:
  type: "lora"
  linear: 16        # LoRA rank
  linear_alpha: 16  # LoRA alpha
train:
  batch_size: 1
  steps: 2000
  gradient_accumulation: 4  # effective batch size = 1 x 4 = 4
  lr: 1e-4
  optimizer: "adamw8bit"    # 8-bit Adam (sketched below)
  gradient_checkpointing: true
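`adamw8bit` cuts optimizer memory by storing Adam's two moment buffers as 8-bit integers with one scale per block, dequantizing on the fly at each step. A toy Rust sketch of the underlying blockwise absmax quantization (the idea only, not EriDiffusion's actual optimizer):

```rust
/// Quantize a block of f32 optimizer state to i8 plus one f32 scale.
fn quantize_block(block: &[f32]) -> (Vec<i8>, f32) {
    let absmax = block.iter().fold(1e-12f32, |m, v| m.max(v.abs()));
    let q = block.iter().map(|v| (v / absmax * 127.0).round() as i8).collect();
    (q, absmax)
}

/// Recover approximate f32 values from the quantized block.
fn dequantize_block(q: &[i8], absmax: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 / 127.0 * absmax).collect()
}

fn main() {
    let state = vec![0.0031_f32, -0.0008, 0.0127, 0.0004];
    let (q, scale) = quantize_block(&state);
    println!("{:?} ~ {:?}", state, dequantize_block(&q, scale));
}
```

Per-block scales keep quantization error local, and optimizer state drops from 4 bytes per value to roughly 1 byte per value plus one f32 scale per block.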
- LoRA weights saved to `output/[model_name]/checkpoints/`
- Sample images saved to `output/[model_name]/samples/`
- All outputs are ComfyUI-compatible
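ComfyUI's most common LoRA format is the kohya-style safetensors layout: paired `lora_down.weight` / `lora_up.weight` tensors plus an `alpha` scalar per adapted module. A hedged sketch of writing that layout with candle (the `lora_unet_example_block` prefix is illustrative; real key names depend on the layer being adapted):

```rust
use std::collections::HashMap;
use candle_core::{Result, Tensor};

/// Save one LoRA module pair in kohya/ComfyUI-style key layout.
fn save_lora(a: &Tensor, b: &Tensor, alpha: f32, path: &str) -> Result<()> {
    let dev = a.device();
    let prefix = "lora_unet_example_block"; // illustrative module name
    let mut tensors = HashMap::new();
    tensors.insert(format!("{prefix}.lora_down.weight"), a.clone());
    tensors.insert(format!("{prefix}.lora_up.weight"), b.clone());
    tensors.insert(format!("{prefix}.alpha"), Tensor::new(alpha, dev)?);
    candle_core::safetensors::save(&tensors, path)
}
```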
With default settings on a 24GB GPU:
- Batch size 1: ~18-20GB
- With gradient checkpointing: ~16-18GB
- Higher resolutions (1024x1024) may require VAE tiling
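Much of that headroom comes from mixed precision: BF16 halves the footprint of FP32 weights and activations. A minimal candle snippet showing the cast (illustrative):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?; // the trainer is GPU-only
    let w = Tensor::randn(0f32, 1.0, (1024, 1024), &dev)?; // F32: 4 MiB
    let w_bf16 = w.to_dtype(DType::BF16)?;                 // BF16: 2 MiB
    println!("{:?} -> {:?}", w.dtype(), w_bf16.dtype());
    Ok(())
}
```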
- SDXL: U-Net based with dual text encoders (CLIP-L + CLIP-G)
- SD 3.5: MMDiT (Multimodal Diffusion Transformer) with triple text encoding
- Flux: Hybrid architecture with double/single stream blocks
All models use the Trainable-Candle fork which enables:
- Direct `Var::from_tensor()` for trainable parameters (no VarBuilder)
- GPU-accelerated LoRA backward pass with cuBLAS
- Gradient tracking throughout the entire model
- Direct safetensors loading without inference-only limitations
Standard Candle's VarBuilder returns immutable `Tensor` objects, making training impossible. The Trainable-Candle fork bypasses this entirely, allowing us to create trainable `Var` objects directly and implement proper backpropagation.
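A minimal sketch of what direct-`Var` training enables, assuming the fork keeps upstream Candle's `Var`/`backward()` API (a toy scalar fit, not the real training loop):

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::cuda_if_available(0)?;
    // Trainable parameter created directly from a tensor -- no VarBuilder involved.
    let w = Var::from_tensor(&Tensor::new(0.0f32, &dev)?)?;
    let x = Tensor::new(3.0f32, &dev)?;
    let target = Tensor::new(6.0f32, &dev)?;

    for _ in 0..100 {
        let pred = x.mul(w.as_tensor())?;
        let loss = (pred - &target)?.sqr()?;
        let grads = loss.backward()?; // gradients flow back into the Var
        if let Some(g) = grads.get(&w) {
            // Plain SGD step: w <- w - lr * g
            w.set(&(w.as_tensor() - (g * 0.01)?)?)?;
        }
    }
    // Fits y = w * x to (x = 3, y = 6), so w converges to ~2.0.
    println!("learned w = {}", w.as_tensor().to_scalar::<f32>()?);
    Ok(())
}
```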
- ✅ LoRA training for SDXL, SD 3.5, Flux
- ✅ Basic sampling during training
- ✅ Memory optimizations for 24GB GPUs
- 🚧 Full model fine-tuning support
- 🚧 Complete sampling for all models
- 🚧 Additional model architectures
- 📋 Video model support (Wan Vace 2.1, LTX, Hunyuan)
- 📋 Multi-GPU distributed training
- 📋 Advanced adaptation methods (DoRA, LoKr)
See CLAUDE.md for detailed development guidelines and model specifications.
MIT OR Apache-2.0