@@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
         self.assertEqual(weight.quant_max, 15)
         self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
 
-    def test_offload(self):
+    def test_device_map(self):
         """
-        Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
-        that the device map is correctly set (in the `hf_device_map` attribute of the model).
+        Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
+        The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
+        correctly set (in the `hf_device_map` attribute of the model).
         """
 
-        device_map_offload = {
+        custom_device_map_dict = {
             "time_text_embed": torch_device,
             "context_embedder": torch_device,
             "x_embedder": torch_device,
@@ -293,27 +294,50 @@ def test_offload(self):
             "norm_out": torch_device,
             "proj_out": "cpu",
         }
+        device_maps = ["auto", custom_device_map_dict]
 
         inputs = self.get_dummy_tensor_inputs(torch_device)
-
-        with tempfile.TemporaryDirectory() as offload_folder:
-            quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
-            quantized_model = FluxTransformer2DModel.from_pretrained(
-                "hf-internal-testing/tiny-flux-pipe",
-                subfolder="transformer",
-                quantization_config=quantization_config,
-                device_map=device_map_offload,
-                torch_dtype=torch.bfloat16,
-                offload_folder=offload_folder,
-            )
-
-            self.assertTrue(quantized_model.hf_device_map == device_map_offload)
-
-            output = quantized_model(**inputs)[0]
-            output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-
-            expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
-            self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+        expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
+
+        for device_map in device_maps:
+            device_map_to_compare = {"": 0} if device_map == "auto" else device_map
+
+            # Test non-sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-pipe",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
+
+            # Test sharded model
+            with tempfile.TemporaryDirectory() as offload_folder:
+                quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
+                quantized_model = FluxTransformer2DModel.from_pretrained(
+                    "hf-internal-testing/tiny-flux-sharded",
+                    subfolder="transformer",
+                    quantization_config=quantization_config,
+                    device_map=device_map,
+                    torch_dtype=torch.bfloat16,
+                    offload_folder=offload_folder,
+                )
+
+                self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
+
+                output = quantized_model(**inputs)[0]
+                output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+
+                self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
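Below is a minimal, self-contained sketch of the `from_pretrained` call that `test_device_map` exercises, kept outside the diff for reference. It reuses the arguments shown above (the tiny test checkpoint, `group_size=64`, `torch.bfloat16`); the assumption that `device_map="auto"` resolves to `{"": 0}` holds only on a single-GPU machine, which is what the test asserts.

```python
# Sketch only (not part of the diff): load the torchao int4 weight-only quantized
# Flux transformer with a device map, mirroring the call made in test_device_map.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=quantization_config,
    device_map="auto",  # or a dict like custom_device_map_dict with "cpu"/"disk" entries
    torch_dtype=torch.bfloat16,
)

# The resolved placement is recorded on the model; on a single GPU, "auto"
# typically yields {"": 0}, which is the value the test compares against.
print(transformer.hf_device_map)
```

A custom dict is passed through unchanged, which is why the test compares `hf_device_map` against the map it supplied; disk offload additionally requires `offload_folder`, as in the test.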