
Commit 5c3f99a

Author: LRL-ModelCloud (committed)
add granite support
Parent: bd8d07e

File tree

8 files changed (+16 / -10 lines)

gptqmodel/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -16,6 +16,7 @@
 from .gpt_bigcode import GPTBigCodeGPTQ
 from .gpt_neox import GPTNeoXGPTQ
 from .gptj import GPTJGPTQ
+from .granite import GraniteGPTQ
 from .grinmoe import GrinMOEGPTQ
 from .internlm import InternLMGPTQ
 from .internlm2 import InternLM2GPTQ
@@ -24,6 +25,7 @@
 from .minicpm3 import MiniCPM3GPTQ
 from .mistral import MistralGPTQ
 from .mixtral import MixtralGPTQ
+from .mllama import MLlamaGPTQ
 from .moss import MOSSGPTQ
 from .mpt import MPTGPTQ
 from .opt import OPTGPTQ
@@ -37,4 +39,3 @@
 from .starcoder2 import Starcoder2GPTQ
 from .xverse import XverseGPTQ
 from .yi import YiGPTQ
-from .mllama import MLlamaGPTQ

gptqmodel/models/_const.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def get_device_by_type(type_value: str):
     "exaone",
     "grinmoe",
     "mllama",
+    "granite",
 ]

 EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

gptqmodel/models/auto.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from .gpt_bigcode import GPTBigCodeGPTQ
 from .gpt_neox import GPTNeoXGPTQ
 from .gptj import GPTJGPTQ
+from .granite import GraniteGPTQ
 from .grinmoe import GrinMOEGPTQ
 from .internlm import InternLMGPTQ
 from .internlm2 import InternLM2GPTQ
@@ -87,6 +88,7 @@
     "exaone": ExaoneGPTQ,
     "grinmoe": GrinMOEGPTQ,
     "mllama": MLlamaGPTQ,
+    "granite": GraniteGPTQ,
 }
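Note: the hunks above only register the new architecture; the granite.py model definition itself is not among the files shown in this diff. As a rough illustration of what such a definition looks like in this repo, the sketch below follows the pattern of the other BaseGPTQModel subclasses (e.g. MLlamaGPTQ further down). The attribute names and Granite module paths are assumptions for illustration, not the contents of the committed file.

# Hypothetical sketch of a gptqmodel/models/granite.py-style definition.
# Attribute names and module paths are assumed from sibling model classes.
from .base import BaseGPTQModel


class GraniteGPTQ(BaseGPTQModel):
    # modules that sit outside the repeated decoder stack
    base_modules = ["model.embed_tokens", "model.norm"]

    # where the decoder layers live and what each layer class is called
    layers_node = "model.layers"
    layer_type = "GraniteDecoderLayer"

    # per-layer Linear submodules to quantize, grouped by execution order
    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]

With the "granite": GraniteGPTQ entry added to the map in auto.py, a checkpoint whose config reports model_type == "granite" resolves to this class automatically.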

gptqmodel/models/mllama.py

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,8 @@
-from .base import BaseGPTQModel
 from transformers import AutoModelForPreTraining

+from .base import BaseGPTQModel
+
+
 # TODO FIXME: we currently do not support quantizing cross attention layer (pixel_values)
 class MLlamaGPTQ(BaseGPTQModel):
     # AutoModelForPreTraining return a correct MLlamaForConditionalGeneration for mllama.

gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Tuple, Optional
+from typing import Optional, Tuple

 import torch.nn as nn

gptqmodel/nn_modules/qlinear/qlinear_qbits.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 def qbits_dtype() -> torch.dtype:
     try:
         from intel_extension_for_transformers import qbits
-    except Exception as e:
+    except Exception:
         raise ImportError("intel_extension_for_transformers not installed. "
                           "Please install via via 'pip install intel_extension_for_transformers")

@@ -112,7 +112,7 @@ def post_init(self, quantize_config):

         try:
             from intel_extension_for_transformers import qbits
-        except Exception as e:
+        except Exception:
             raise ImportError("intel_extension_for_transformers not installed. "
                               "Please install via via 'pip install intel_extension_for_transformers")

@@ -257,7 +257,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
     def forward(self, x: torch.Tensor):
         try:
             from intel_extension_for_transformers import qbits
-        except Exception as e:
+        except Exception:
             raise ImportError("intel_extension_for_transformers not installed. "
                               "Please install via via 'pip install intel_extension_for_transformers")
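All three hunks above are the same cleanup: the bound exception variable e was never used, so the except clause drops it. A standalone sketch of the resulting optional-dependency guard, with the module and error text taken from the diff (the helper name require_qbits is illustrative and not part of the repo):

# Sketch of the import-guard pattern used in qlinear_qbits.py after this change.
# The helper name is illustrative; only the guarded import comes from the diff.
def require_qbits():
    try:
        from intel_extension_for_transformers import qbits  # optional CPU backend
    except Exception:
        # the exception object is not needed, so it is not bound with "as e"
        raise ImportError(
            "intel_extension_for_transformers not installed. "
            "Please install via 'pip install intel_extension_for_transformers'"
        )
    return qbits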

setup.py

Lines changed: 3 additions & 3 deletions
(The removed and added lines below are identical in visible content; this hunk is a whitespace-only cleanup of the flag list.)

@@ -110,9 +110,9 @@
     extra_compile_args = {
         "cxx": [
             "-O3",
-            "-std=c++17",
-            "-fopenmp",
-            "-lgomp",
+            "-std=c++17",
+            "-fopenmp",
+            "-lgomp",
             "-DENABLE_BF16"
             "-Wno-switch-bool",
         ],

tests/test_sharded.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 import unittest  # noqa: E402

 import torch  # noqa: E402
-from gptqmodel import BACKEND, GPTQModel  # noqa: E402
+from gptqmodel import GPTQModel  # noqa: E402
 from transformers import AutoTokenizer  # noqa: E402

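The trimmed import reflects that this test no longer selects an explicit BACKEND. For orientation, here is a minimal sketch of loading a quantized checkpoint with the remaining imports; the model id is a placeholder and GPTQModel.from_quantized is assumed to be the available loader entry point (the exact name may differ between versions):

# Illustrative only: model id and loader entry point are assumptions.
from gptqmodel import GPTQModel
from transformers import AutoTokenizer

model_id = "path/to/a-gptq-quantized-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = GPTQModel.from_quantized(model_id, device="cuda:0")

inputs = tokenizer("gptqmodel is", return_tensors="pt").to("cuda:0")
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))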
