
Commit 2587e1a

[FIX] gptq v2 load (#724)
* transformers
  Signed-off-by: jiqing-feng <[email protected]>
* add hf_select_quant_linear
  Signed-off-by: jiqing-feng <[email protected]>
* add transformers inference example
  Signed-off-by: jiqing-feng <[email protected]>
* add unittest

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: jiqing-feng <[email protected]>
1 parent d8a802e commit 2587e1a

File tree

5 files changed: +133 -17 lines changed

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ")
+quantized_model = AutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ")
+print(tokenizer.decode(quantized_model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(quantized_model.device))[0]))
+
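The committed inference example is intentionally minimal. A slightly more explicit variant is sketched below; the device_map="auto" placement and the max_new_tokens cap are illustrative assumptions, not part of the commit.

# Hedged variant of the committed example: device_map and max_new_tokens are
# illustrative assumptions, not settings used by the commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("gptqmodel is", return_tensors="pt").to(quantized_model.device)
output_ids = quantized_model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0]))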
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
+model_id = "facebook/opt-125m"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
+gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", quantization_config=gptq_config)
+quantized_model.save_pretrained("./opt-125m-gptq")
+tokenizer.save_pretrained("./opt-125m-gptq")
+
+model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map="auto")
+
+print(tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0]))
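The quantization example above only sets bits, dataset, and tokenizer. For readers tuning the quantization, a hedged variant with explicit group_size and desc_act values is sketched below; those two values are illustrative assumptions, not settings used by the commit.

# Hedged sketch: group_size=128 and desc_act=False are illustrative assumptions;
# the committed example leaves these GPTQConfig options at their defaults.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]

gptq_config = GPTQConfig(bits=4, group_size=128, desc_act=False, dataset=dataset, tokenizer=tokenizer)
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", quantization_config=gptq_config)
quantized_model.save_pretrained("./opt-125m-gptq")
tokenizer.save_pretrained("./opt-125m-gptq")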

gptqmodel/utils/importer.py

Lines changed: 42 additions & 17 deletions
@@ -34,38 +34,63 @@
 }
 
 
+def hf_select_quant_linear(
+    bits: int,
+    group_size: int,
+    desc_act: bool,
+    sym: bool,
+    backend: BACKEND = BACKEND.AUTO,
+    format: FORMAT = FORMAT.GPTQ,
+    pack: bool = False,
+    dynamic=None,
+):
+    return select_quant_linear(
+        bits=bits,
+        group_size=group_size,
+        desc_act=desc_act,
+        sym=sym,
+        backend=backend,
+        format=format,
+        pack=pack,
+        dynamic=dynamic,
+    )
+
+
 # auto select the correct/optimal QuantLinear class
 def select_quant_linear(
     bits: int,
     group_size: int,
     desc_act: bool,
     sym: bool,
-    backend: BACKEND,
-    format: FORMAT,
+    backend: BACKEND = BACKEND.AUTO,
+    format: FORMAT = FORMAT.GPTQ,
     pack: bool = False,
     dynamic=None,
 ):
     # Handle the case where backend is AUTO.
     if backend == BACKEND.AUTO:
-        allow_backends = format_dict[format]
-        err = None
-        for k, values in backend_dict.items():
+        if not torch.cuda.is_available():
+            backend = BACKEND.IPEX
+        else:
+            allow_backends = format_dict[format]
+            err = None
+            for k, values in backend_dict.items():
 
-            for v in values:
-                in_allow_backends = k in allow_backends
-                validate, err = v.validate(bits, group_size, desc_act, sym, dynamic=dynamic)
-                if in_allow_backends and validate:
-                    if pack:
-                        check_pack_func = hasattr(v, "pack")
-                        if check_pack_func:
+                for v in values:
+                    in_allow_backends = k in allow_backends
+                    validate, err = v.validate(bits, group_size, desc_act, sym, dynamic=dynamic)
+                    if in_allow_backends and validate:
+                        if pack:
+                            check_pack_func = hasattr(v, "pack")
+                            if check_pack_func:
+                                logger.info(f"Auto choose the fastest one based on quant model compatibility: {v}")
+                                return v
+                        else:
                             logger.info(f"Auto choose the fastest one based on quant model compatibility: {v}")
                             return v
-                    else:
-                        logger.info(f"Auto choose the fastest one based on quant model compatibility: {v}")
-                        return v
 
-        if err:
-            raise err
+            if err:
+                raise err
 
     # Handle the case where backend is not AUTO.
     if backend == BACKEND.TRITON:
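For reference, a minimal sketch of calling the new hf_select_quant_linear wrapper directly is shown below. The import paths simply mirror the file paths in this diff, and the bits/group_size/desc_act/sym values are illustrative assumptions; with the default BACKEND.AUTO, the patched selection logic resolves to the IPEX kernel on a host without CUDA.

# Minimal sketch (parameter values are assumptions): with the default
# BACKEND.AUTO, the patched selection falls back to IPEX when CUDA is absent.
from gptqmodel.utils.backend import BACKEND
from gptqmodel.utils.importer import hf_select_quant_linear

QuantLinear = hf_select_quant_linear(
    bits=4,
    group_size=128,
    desc_act=False,
    sym=True,
    backend=BACKEND.AUTO,
)
print(QuantLinear)  # the QuantLinear class selected for this configuration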

gptqmodel/utils/model.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,7 @@
 from transformers.utils.hub import cached_file
 
 from .backend import BACKEND
+from .exllama import exllama_set_max_input_length
 from .importer import select_quant_linear
 from .logger import setup_logger
 from .progress import ProgressBar
@@ -536,6 +537,9 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon
 
     torch.cuda.empty_cache()
 
+    # if use_act_order and max_input_length and isinstance(submodule, ExllamaQuantLinear):
+    #     model = exllama_set_max_input_length(model, max_input_length)
+
     return model
 
 

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+import unittest
+from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
+
+class TestTransformersIntegration(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(self):
+        pass
+
+    def _test_load_quantized_model_gptq_v1(self, device_map):
+        model_id_or_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ"
+        tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
+        quantized_model = AutoModelForCausalLM.from_pretrained(model_id_or_path,
+                                                               device_map=device_map,)
+        generate_str = tokenizer.decode(quantized_model.generate(**tokenizer("The capital of France is is", return_tensors="pt").to(quantized_model.device))[0])
+        expect_str = "<s> The capital of France is is Paris.\nThe capital of France is Paris.\nThe capital of France is Paris.\nThe capital of France is Paris.\nThe capital of France is"
+        self.assertEqual(generate_str[:50], expect_str[:50])
+
+    def _test_load_quantized_model_gptq_v2(self, device_map):
+        model_id_or_path = "/monster/data/model/opt-125m/quant/2024-12-02_13-28-10_subcircularly_autogptq_version_pr640_bit4_group128_seq2048_batch16/damp0.1_descTrue_gptq_v2_symTrue_pack_dataFalse_mseTrue_mse_norm2.4_mse_grid100_mse_maxshrink0.8/c40_gr0_dic0_sen0_det0_rate0_native0_lm_compression1024_text_reduction0/opt_125m_gptqv2"
+        tokenizer = AutoTokenizer.from_pretrained(model_id_or_path)
+        quantized_model = AutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+                                                               device_map=device_map,)
+        generate_str = tokenizer.decode(quantized_model.generate(**tokenizer("The capital of France is is", return_tensors="pt").to(quantized_model.device))[0])
+        expect_str = "</s>The capital of France is is found velvetJustice ten for bowel Tuesday"
+
+        self.assertEqual(generate_str[:len(expect_str)], expect_str)
+
+    def _test_quantize(self, device_map):
+        model_id = "facebook/opt-125m"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        dataset = [
+            "gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
+        gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
+        quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map,
+                                                               quantization_config=gptq_config)
+        quantized_model.save_pretrained("./opt-125m-gptq")
+        tokenizer.save_pretrained("./opt-125m-gptq")
+
+        model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map=device_map)
+
+        generate_str = tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0])
+
+        expect_str = "</s>gptqmodel is a good way to get a good way for a good way for a good way."
+
+        print('generate_str',generate_str)
+        print('expect_str',expect_str)
+
+        self.assertEqual(generate_str[:40], expect_str[:40])
+
+    def test_load_quantized_model_gptq_v1_ipex(self):
+        self._test_load_quantized_model_gptq_v1(device_map="cpu")
+
+    def test_load_quantized_model_gptq_v1_cuda(self):
+        self._test_load_quantized_model_gptq_v1(device_map="cuda")
+
+    def test_load_quantized_model_gptq_v2_ipex(self):
+        self._test_load_quantized_model_gptq_v2(device_map="cpu")
+
+    def test_load_quantized_model_gptq_v2_cuda(self):
+        self._test_load_quantized_model_gptq_v2(device_map="cuda")
+
+    def test_quantize_ipex(self):
+        self._test_quantize(device_map="cpu")
+
+    def test_quantize_cuda(self):
+        self._test_quantize(device_map="cuda")
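The new tests follow the standard unittest layout, so they can be driven by the usual loader. A hypothetical runner is sketched below; the module name test_transformers is an assumption, since the committed file's path is not visible in this view.

# Hypothetical runner: the module name "test_transformers" is an assumption;
# this diff view does not show the actual file name of the new test module.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_transformers.TestTransformersIntegration"
)
unittest.TextTestRunner(verbosity=2).run(suite)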
