# [core] TorchAO Quantizer #10009

@@ -0,0 +1,31 @@

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# torchao

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features like `torch.compile` and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).

Before you begin, make sure you have PyTorch 2.5 or above and TorchAO installed:

```bash
pip install -U torch torchao
```

## Usage

Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

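Below is a minimal end-to-end sketch of that flow. The Flux checkpoint, the `int8wo` (int8 weight-only) quantization type, and the prompt are illustrative choices, not requirements:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"  # illustrative checkpoint

# Quantize only the transformer; "int8wo" selects int8 weight-only quantization.
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

# Plug the quantized transformer into the full pipeline.
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("output.png")
```

Because TorchAO composes with `torch.compile`, the quantized transformer should also be compilable, e.g. `pipe.transformer = torch.compile(pipe.transformer)`.
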
## Resources

- [TorchAO Quantization API]()
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)

```diff
@@ -671,10 +671,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         hf_quantizer = None

         if hf_quantizer is not None:
-            if device_map is not None:
-                raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
-                )
```

Review discussion on this change:

> @sayakpaul I'm not sure how this impacts the BnB quantizer. I assume it was disabled for BnB for some reason I'm not aware of. It works with TorchAO as expected, though, so if you need some torchao-specific guard here, I'll add it.

> cc @SunMarc

> How did you test that? That is because we merge the sharded checkpoints when using [...]. This is not hit when loading quantized checkpoints, at least for [...].

> There is a test for this in [...].

> Yeah, but this is about custom user-provided [...].

> The check of [...]: I would like to know if there should be an error raised if the BnB quantizer is the method used. Something like:
>
> ```
> if quantization method is BnB and device_map is not None:
>     raise Error
> ```
>
> Does that work?

> Yeah, that works for me.

```diff
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
```

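To make the proposal concrete, here is a hedged sketch of the BnB-only guard discussed above. The helper name `_validate_device_map` is hypothetical, and the placement and error message are assumptions rather than the merged code:

```python
from diffusers.quantizers.quantization_config import QuantizationMethod


def _validate_device_map(hf_quantizer, device_map):
    """Hypothetical helper: reject a user-provided `device_map` only for bitsandbytes."""
    if hf_quantizer is None or device_map is None:
        return  # nothing to validate
    if hf_quantizer.quantization_config.quant_method == QuantizationMethod.BITS_AND_BYTES:
        raise NotImplementedError(
            "Currently, `device_map` is automatically inferred for bitsandbytes-quantized "
            "models. Support for providing `device_map` as an input will be added in the future."
        )
```

This keeps TorchAO loading permissive while preserving the old behavior for bitsandbytes, which is what the thread converges on.
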
```diff
@@ -829,13 +825,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         if device_map is None and not is_sharded:
             # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
             # It would error out during the `validate_environment()` call above in the absence of cuda.
-            is_quant_method_bnb = (
-                getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-            )
             if hf_quantizer is None:
                 param_device = "cpu"
             # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-            elif is_quant_method_bnb:
+            else:
                 param_device = torch.cuda.current_device()

         state_dict = load_state_dict(model_file, variant=variant)
         model._convert_deprecated_attention_blocks(state_dict)
```

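Taken together, the two hunks let a user-provided `device_map` reach the quantized loading path instead of failing up front. A hedged sketch of what that unlocks; the checkpoint is illustrative and `device_map="auto"` is an assumed value, since support for custom maps is exactly what the thread above leaves open:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Before this PR, any quantizer plus an explicit device_map raised
# NotImplementedError; with the guard removed, the combination is allowed
# for TorchAO (behavior for custom maps is still under discussion).
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),
    device_map="auto",  # assumed value
    torch_dtype=torch.bfloat16,
)
```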