[core] TorchAO Quantizer #10009
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# torchao
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more. | ||||
|
||||
Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed. | ||||

```bash
pip install -U torch torchao
```

Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

The example below only quantizes the weights to int8.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```

TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.

```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```

For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find torchao [benchmark](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.

torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy for a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components; Diffusers will also expose an autoquant configuration option in the future.

The `TorchAoConfig` class accepts three parameters:
- `quant_type`: A string specifying one of the quantization types below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantization of the [`FluxTransformer2DModel`]'s first block, specify `modules_to_not_convert=["single_transformer_blocks.0"]`.
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is selected based on `quant_type`.
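
As an illustration, the hypothetical configuration below uses all three parameters together. It assumes `group_size` is a valid keyword argument of torchao's underlying `int4_weight_only` method, so treat it as a sketch rather than a tested recipe:

```python
from diffusers import TorchAoConfig

# int4 weight-only quantization, skipping the first single-transformer block;
# extra kwargs (here `group_size`, an assumption) are forwarded to torchao's
# int4_weight_only quantization method.
quantization_config = TorchAoConfig(
    "int4wo",
    modules_to_not_convert=["single_transformer_blocks.0"],
    group_size=64,
)
```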
## Supported quantization types

torchao supports weight-only quantization, and combined weight and dynamic-activation quantization, for int8, float3 through float8, and uint1 through uint7 data types.

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
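
The weight-only scheme can be illustrated with a minimal pure-Python sketch of symmetric int8 quantization. This is illustrative arithmetic only; torchao's actual kernels use per-group scales and packed low-bit layouts, and `quantize_int8`/`dequantize` are hypothetical helper names:

```python
def quantize_int8(weights):
    """Symmetric int8 quantization: store int8 values plus one float scale."""
    scale = max(abs(w) for w in weights) / 127.0
    qweights = [max(-128, min(127, round(w / scale))) for w in weights]
    return qweights, scale

def dequantize(qweights, scale):
    """Dequantize back to higher precision before the computation happens."""
    return [q * scale for q in qweights]

weights = [0.5, -1.27, 0.03, 1.0]
qweights, scale = quantize_int8(weights)
restored = dequantize(qweights, scale)
# Each restored value is within one quantization step (scale) of the original.
assert all(abs(w - r) <= scale for w, r in zip(weights, restored))
```

The weights are stored as int8 plus a single scale, but `dequantize` runs before every use, which is why activation memory peaks are unchanged.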

Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come with a quality tradeoff, so it is recommended to test different models thoroughly.
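
By contrast, a dynamic-activation sketch quantizes the activations at call time and accumulates in integers, rescaling once at the end. Again, this is illustrative pure Python with hypothetical names, not torchao's actual implementation:

```python
def quantize_sym_int8(values):
    """Symmetric int8 quantization: int8 values plus one float scale."""
    scale = max(abs(v) for v in values) / 127.0
    return [max(-128, min(127, round(v / scale))) for v in values], scale

def dynamic_int8_dot(weights, activations):
    """Dot product with int8 weights and activations quantized on-the-fly.
    Accumulation happens in integers; one rescale recovers a float result."""
    qw, w_scale = quantize_sym_int8(weights)
    qa, a_scale = quantize_sym_int8(activations)  # re-quantized per call: "dynamic"
    return sum(w * a for w, a in zip(qw, qa)) * w_scale * a_scale

weights = [0.5, -1.27, 0.03, 1.0]
activations = [1.0, 2.0, -0.5, 0.25]
exact = sum(w * a for w, a in zip(weights, activations))
approx = dynamic_int8_dot(weights, activations)
assert abs(exact - approx) < 0.05  # small quantization error
```

Because both operands are int8, intermediate buffers shrink as well, which is where the extra memory savings (and the potential quality loss) come from.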

The quantization methods supported are as follows:

| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB`, where `X` is the number of bits (1-7), `A` is the number of exponent bits, and `B` is the number of mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
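
The `X == A + B + 1` constraint (one sign bit plus exponent and mantissa bits) can be checked mechanically. The helper below is a hypothetical sketch that parses shorthands of the `fpX_eAwB` form shown in the table; it is not part of torchao or Diffusers:

```python
import re

def check_fpx_shorthand(name):
    """Validate an fpX_eAwB shorthand against the X == A + B + 1 constraint
    (1 sign bit + A exponent bits + B mantissa bits)."""
    match = re.fullmatch(r"fp(\d+)_e(\d+)w(\d+)", name)
    if match is None:
        raise ValueError(f"not an fpX_eAwB shorthand: {name!r}")
    x, a, b = (int(g) for g in match.groups())
    if not 1 <= x <= 7:
        raise ValueError(f"X must be between 1 and 7, got {x}")
    if x != a + b + 1:
        raise ValueError(f"constraint X == A + B + 1 violated: {x} != {a} + {b} + 1")
    return x, a, b

# fp6 with 3 exponent bits and 2 mantissa bits satisfies 6 == 3 + 2 + 1
assert check_fpx_shorthand("fp6_e3w2") == (6, 3, 2)
```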

Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.

Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.

## Resources

- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)