diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index e81ad7883..b0f9a84f2 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -691,7 +691,8 @@ def save_quantized(
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
             model_save_name = model_base_name + ".safetensors"
         else:
-            model_save_name = model_base_name + ".bin"
+            model_save_name = model_base_name + ".pt"
+
         if not self.qlinear_kernel.SUPPORTS_SHARDS and max_shard_size is not None:
             logger.warning("Sharding is not supported for this quant. Disabling sharding.")
             max_shard_size = None
@@ -1106,7 +1107,6 @@ def from_quantized(
         use_safetensors: bool = True,
         trust_remote_code: bool = False,
         format: Optional[FORMAT] = None,
-        allow_unsafe_loading: bool = False,
         verify_hash: Optional[Union[str, List[str]]] = None,
         **kwargs,
     ):
@@ -1247,7 +1247,7 @@ def from_quantized(
         if use_safetensors:
             extensions.append(".safetensors")
         else:
-            extensions += [".bin", ".pt"]
+            extensions += [".pt", ".pth"]
 
         model_name_or_path = str(model_name_or_path)
 
@@ -1260,14 +1260,9 @@ def from_quantized(
 
         # bin files have security issues: disable loading by default
        if ".bin" in resolved_archive_file:
-            if allow_unsafe_loading:
-                logger.warning(
-                    "There are security risks when loading tensors from .bin files. Make sure you are loading model only from a trusted source."
-                )
-            else:
-                raise ValueError(
-                    "Loading of unsafe .bin files are not allowed by default. Pass allow_unsafe_loading=True to bypass."
-                )
+            raise ValueError(
+                "Loading of .bin files is not allowed for safety reasons. Please convert your model to safetensors or PyTorch (.pt) format."
+            )
 
         quantize_config.runtime_format = quantize_config.format
 
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 87225f1da..852ec3af4 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -543,7 +543,7 @@ def get_checkpoints(
     model_name_or_path: str, extensions: List[str], **cached_file_kwargs
 ):
     """
-    Retrives (and if necessary downloads from Hugging Face Hub) the model checkpoint. Sharding is supported. All the `possible_model_basenames` (e.g. `["model", "model-4bit-gptq"]`) will be explored over all `extensions` (e.g. `[".bin", ".safetensors"]`).
+    Retrieves (and if necessary downloads from Hugging Face Hub) the model checkpoint. Sharding is supported. All the `possible_model_basenames` (e.g. `["model", "model-4bit-gptq"]`) will be explored over all `extensions` (e.g. `[".safetensors"]`).
     """
     searched_files = []
     resolved_archive_file = None
diff --git a/tests/test_pt.py b/tests/test_pt.py
new file mode 100644
index 000000000..52294f42f
--- /dev/null
+++ b/tests/test_pt.py
@@ -0,0 +1,37 @@
+import torch
+import unittest
+
+from transformers import AutoTokenizer
+
+from gptqmodel import GPTQModel, QuantizeConfig
+
+pretrained_model_id = "facebook/opt-125m"
+quantized_model_id = "facebook-opt-125m"
+
+class Test_save_load_pt_weight(unittest.TestCase):
+    def test_pt(self):
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id, use_fast=True)
+        calibration_dataset = [
+            tokenizer(
+                "gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
+            )
+        ]
+
+        reference_output = "gptqmodel is an easy-to-use model for creating a variety of a variety"
+
+        quantize_config = QuantizeConfig(
+            bits=4,
+            group_size=128,
+        )
+
+        model = GPTQModel.from_pretrained(pretrained_model_id, quantize_config)
+
+        model.quantize(calibration_dataset)
+
+        model.save_quantized(quantized_model_id, use_safetensors=False)
+
+        model = GPTQModel.from_quantized(quantized_model_id, device="cuda:0", use_safetensors=False)
+
+        result = tokenizer.decode(model.generate(**tokenizer("gptqmodel is an easy-to-use model", return_tensors="pt").to(model.device))[0])
+
+        self.assertEqual(result, reference_output)
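
Migration note (not part of the diff above): with `.bin` loading now rejected in `from_quantized`, checkpoints previously saved as `.bin` must be converted before they can be loaded again, as the new error message suggests. The snippet below is a minimal, illustrative sketch of one way to do that offline with `torch.load` and `safetensors.torch.save_file`; the file names are placeholders and the snippet is not something this PR itself adds or prescribes.

    # Hypothetical one-off conversion of a legacy .bin checkpoint to safetensors.
    import torch
    from safetensors.torch import save_file

    src = "model-4bit-gptq.bin"          # placeholder: legacy checkpoint you trust
    dst = "model-4bit-gptq.safetensors"  # placeholder: converted output path

    # weights_only=True restricts unpickling to tensors and primitive types, which is
    # the security concern behind this PR; still only convert files you trust.
    state_dict = torch.load(src, map_location="cpu", weights_only=True)

    # safetensors requires contiguous, non-aliased tensors, mirroring the
    # clone().contiguous() step in save_quantized() above.
    state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}

    save_file(state_dict, dst, metadata={"format": "pt"})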