Commit 41ba8c0

Add support for sharded models when TorchAO quantization is enabled (#10256)
* add sharded + device_map check
1 parent 3191248 commit 41ba8c0

File tree

2 files changed: +48 −24 lines changed

src/diffusers/models/modeling_utils.py
tests/quantization/torchao/test_torchao.py

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -802,7 +802,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                    revision=revision,
                    subfolder=subfolder or "",
                )
-               if hf_quantizer is not None:
+               if hf_quantizer is not None and is_bnb_quantization_method:
                    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                    logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                    is_sharded = False

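With this gate, `from_pretrained` merges sharded checkpoints into a single file only for bitsandbytes quantization; for other quantizers such as TorchAO the shards are loaded through the regular sharded path, which is what lets them be combined with a `device_map`. A minimal usage sketch of the newly supported combination, using the tiny sharded test repo from the new test below (the kwargs mirror that test and are not an exhaustive `from_pretrained` reference):

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# TorchAO int4 weight-only quantization applied while loading a *sharded*
# checkpoint together with a device_map -- the combination this commit enables.
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-sharded",  # sharded checkpoint used by the new test
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(transformer.hf_device_map)  # e.g. {"": 0} when everything fits on one device
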
tests/quantization/torchao/test_torchao.py

Lines changed: 47 additions & 23 deletions
@@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
        self.assertEqual(weight.quant_max, 15)
        self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))

-    def test_offload(self):
+    def test_device_map(self):
        """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
-        that the device map is correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+        correctly set (in the `hf_device_map` attribute of the model).
        """

-        device_map_offload = {
+        custom_device_map_dict = {
            "time_text_embed": torch_device,
            "context_embedder": torch_device,
            "x_embedder": torch_device,
@@ -293,27 +294,50 @@ def test_offload(self):
            "norm_out": torch_device,
            "proj_out": "cpu",
        }
+        device_maps = ["auto", custom_device_map_dict]

        inputs = self.get_dummy_tensor_inputs(torch_device)
-
-        with tempfile.TemporaryDirectory() as offload_folder:
-            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-            quantized_model = FluxTransformer2DModel.from_pretrained(
-                "hf-internal-testing/tiny-flux-pipe",
-                subfolder="transformer",
-                quantization_config=quantization_config,
-                device_map=device_map_offload,
-                torch_dtype=torch.bfloat16,
-                offload_folder=offload_folder,
-            )
-
-            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
-
-            output = quantized_model(**inputs)[0]
-            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+
+        for device_map in device_maps:
+            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
+
+            # Test non-sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-pipe",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+            # Test sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-sharded",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

    def test_modules_to_not_convert(self):
        quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])

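The renamed test can be run on its own with a standard pytest filter, e.g. `pytest tests/quantization/torchao/test_torchao.py -k test_device_map` (the `-k` expression simply selects the test by name; torchao and a CUDA device are assumed to be available for it to actually execute).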