diff --git a/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py b/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
index 94788331f..5ed64b17a 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
@@ -92,6 +92,8 @@ class BitBLASQuantLinear(BaseQuantLinear):
         torch.half: "float16",
         torch.int8: "int8",
     }
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "bitblas"
 
     def __init__(
         self,
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
index a7684bc72..bf6a84c13 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -23,6 +23,8 @@ class CudaQuantLinear(TorchQuantLinear):
     SUPPORTS_BITS = [2, 3, 4, 8]
     SUPPORTS_DEVICES = [DEVICE.CUDA]
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "cuda"
 
     def __init__(
         self,
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
index a439365a5..7ca5ccf6f 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -45,6 +45,8 @@ class ExllamaQuantLinear(BaseQuantLinear):
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_DEVICES = [DEVICE.CUDA]
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "exllama"
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
index 9d18024c5..45d6a99ad 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -107,6 +107,8 @@ class ExllamaV2QuantLinear(BaseQuantLinear):
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_DEVICES = [DEVICE.CUDA]
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "exllamav2"
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_ipex.py b/gptqmodel/nn_modules/qlinear/qlinear_ipex.py
index 7e289e64b..99e888b93 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_ipex.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_ipex.py
@@ -51,6 +51,8 @@ def convert_dtype_torch2str(dtype):
 class IPEXQuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [4]
     SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU]
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "ipex"
 
     def __init__(
         self,
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
index b3500629f..a8c80cf73 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -145,6 +145,8 @@ class MarlinQuantLinear(BaseQuantLinear):
     SUPPORTS_SYM = [True]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64]
     SUPPORTS_DEVICES = [DEVICE.CUDA]
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "marlin"
 
     def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
                  bias: bool, **kwargs):
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_torch.py b/gptqmodel/nn_modules/qlinear/qlinear_torch.py
index a52437fb6..95ecd2091 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_torch.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_torch.py
@@ -15,6 +15,8 @@ class TorchQuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [2, 3, 4, 8]
     SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU, DEVICE.CUDA]
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "torch"
 
     def __init__(
         self,
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py b/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
index 0927c49ca..85d062c70 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -31,6 +31,8 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_DEVICES = [DEVICE.CUDA]
+    # for transformers/optimum tests compat
+    QUANT_TYPE = "tritonv2"
 
     """
     Triton v2 quantized linear layer.
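
A minimal sketch of how the new `QUANT_TYPE` strings might be consumed by compatibility code, assuming a hypothetical registry and `find_quant_linear` helper; only the class names, import paths, and `QUANT_TYPE` values come from the diff above.

```python
# Hypothetical usage sketch: resolve a backend name via the QUANT_TYPE class attribute.
# The _REGISTRY list and find_quant_linear helper are illustrative assumptions,
# not GPTQModel/transformers/optimum API.
from gptqmodel.nn_modules.qlinear.qlinear_cuda import CudaQuantLinear
from gptqmodel.nn_modules.qlinear.qlinear_torch import TorchQuantLinear

_REGISTRY = [TorchQuantLinear, CudaQuantLinear]

def find_quant_linear(quant_type: str):
    # Compare the requested backend string against each class's QUANT_TYPE.
    for cls in _REGISTRY:
        if getattr(cls, "QUANT_TYPE", None) == quant_type:
            return cls
    raise ValueError(f"no QuantLinear registered for quant_type={quant_type!r}")

assert find_quant_linear("torch") is TorchQuantLinear
```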