
Commit 9840771

Move save_quantized function into saver.py (#467)
* Move the save function into ModelSaver
* Remove unused code
* Rename ModelSaver to ModelWriter
* Pass the checkpoint_file_name value through
* Clean up code
* Set the checkpoint_file_name value
* Pass the variable through gptqmodel/models/loader.py
* Fix an error caused by the checkpoint_file_name variable
Parent: 47113a2

File tree: 3 files changed, +361 −257 lines
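
For context, the public save entry point is unchanged by this refactor; only its body now delegates to ModelWriter. A minimal usage sketch, assuming the GPTQModel public API; the model id, calibration text, and output directory below are illustrative and not taken from this commit:

    from gptqmodel import GPTQModel, QuantizeConfig  # assumed public imports
    from transformers import AutoTokenizer

    model_id = "facebook/opt-125m"  # illustrative model id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    calibration_dataset = [tokenizer("GPTQModel is an easy-to-use LLM quantization toolkit.")]

    model = GPTQModel.from_pretrained(model_id, QuantizeConfig(bits=4, group_size=128))
    model.quantize(calibration_dataset)

    # Same public call as before this commit; internally it now forwards to ModelWriter.save_quantized().
    model.save_quantized("opt-125m-gptq-4bit", use_safetensors=True)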

gptqmodel/models/base.py

Lines changed: 27 additions & 256 deletions
@@ -25,6 +25,7 @@
 from transformers.modeling_utils import no_init_weights, shard_checkpoint
 from transformers.models.mllama.modeling_mllama import MllamaCrossAttentionDecoderLayer
 from transformers.utils.generic import ContextManagers
+from .writer import ModelWriter

 from ..nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear
 from ..nn_modules.qlinear.qlinear_qbits import QBitsQuantLinear, qbits_dtype
@@ -627,264 +628,31 @@ def save_quantized(
         max_shard_size: Optional[str] = None,
         model_base_name: Optional[str] = None
     ):
-        """save quantized model and configs to local disk"""
-        os.makedirs(save_dir, exist_ok=True)

-        pre_quantized_size_mb = get_model_files_size(self.model_name_or_path)
-        pre_quantized_size_gb = pre_quantized_size_mb / 1024
+        checkpoint_file_name = ""
+        if hasattr(self, "checkpoint_file_name") and self.checkpoint_file_name is not None:
+            checkpoint_file_name = self.checkpoint_file_name

-        # write gptqmodel tooling fingerprint to config
-        self.quantize_config.meta_set_versionable(
-            key=META_FIELD_QUANTIZER,
-            value=META_QUANTIZER_GPTQMODEL,
-            version=__version__,
-        )
-
-        self.quantize_config.meta_set(
-            key=META_FIELD_URI,
-            value=META_VALUE_URI,
-        )
-
-        self.quantize_config.meta_set(
-            key=META_FIELD_DAMP_PERCENT,
-            value=self.quantize_config.damp_percent
-        )
-
-        self.quantize_config.meta_set(
-            key=META_FIELD_DAMP_AUTO_INCREMENT,
-            value=self.quantize_config.damp_auto_increment
-        )
-
-        # The config, quantize_config and model may be edited in place in save_quantized.
-        config = copy.deepcopy(self.model.config)
-        quantize_config = copy.deepcopy(self.quantize_config)
-
-        if not self.quantized:
-            raise ValueError("Save aborted as model is not quantized. Please call `quantize()` first.")
-
-        if model_base_name is None:
-            model_base_name = (
-                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
-            )
-
-        if quantize_config.format == FORMAT.GPTQ_V2:
-            logger.warning(
-                f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPTQModel version >= {MIN_VERSION_WITH_V2}."
-            )
-
-        if not self.load_quantized_model:
-            model = self.model
-            # # internal is always gptq v2 but allow users to pass gptq (v1) via config
-            if quantize_config.format == FORMAT.GPTQ:
-                # Model qzeros may be edited in place.
-                model = convert_gptq_v2_to_v1_format(
-                    model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
-                )
-        else:
-            model = self.get_model_with_quantize(quantize_config, self.model_name_or_path)
-        model.to(CPU)
-        state_dict = model.state_dict()
-
-        model_base_name = "model"
-
-        if use_safetensors:
-            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
-            model_save_name = model_base_name + ".safetensors"
-        else:
-            model_save_name = model_base_name + ".pt"
-
-        if not self.qlinear_kernel.SUPPORTS_SHARDS and max_shard_size is not None:
-            logger.warning("Sharding is not supported for this quant. Disabling sharding.")
-            max_shard_size = None
-
-        if max_shard_size is None:
-            if use_safetensors:
-                if safetensors_metadata is None:
-                    safetensors_metadata = {}
-                elif not isinstance(safetensors_metadata, dict):
-                    raise TypeError("safetensors_metadata must be a dictionary.")
-                else:
-                    logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                    new_safetensors_metadata = {}
-                    converted_keys = False
-                    for key, value in safetensors_metadata.items():
-                        if not isinstance(key, str) or not isinstance(value, str):
-                            converted_keys = True
-                            try:
-                                new_key = str(key)
-                                new_value = str(value)
-                            except Exception as e:
-                                raise TypeError(
-                                    f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
-                                )
-                            if new_key in new_safetensors_metadata:
-                                logger.warning(
-                                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
-                                )
-                            new_safetensors_metadata[new_key] = new_value
-                    safetensors_metadata = new_safetensors_metadata
-                    if converted_keys:
-                        logger.debug(
-                            f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
-                        )

-                # Format is required to enable Accelerate to load the metadata
-                # otherwise it raises an OSError
-                safetensors_metadata["format"] = "pt"
-                safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
-            else:
-                logger.warning(
-                    "We highly suggest saving quantized model using safetensors format for security reasons. Please set `use_safetensors=True` whenever possible.")
-                torch.save(model.state_dict(), join(save_dir, model_save_name))
-            total_size_mb = os.path.getsize(join(save_dir, model_save_name)) / (1024 * 1024)
-        else:
-            # Shard checkpoint
-            shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)

-            # Clean the folder from a previous save
-            for filename in os.listdir(save_dir):
-                full_filename = join(save_dir, filename)

-                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
-                filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
-                reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")

-                if (
-                    filename.startswith(model_base_name)
-                    and isfile(full_filename)
-                    and filename not in shards.keys()
-                    and reg.fullmatch(filename_no_suffix) is not None
-                ):
-                    os.remove(full_filename)

-            total_size_mb = 0
-            # Save the model
-            for shard_file, shard in shards.items():
-                if use_safetensors:
-                    if safetensors_metadata is None:
-                        safetensors_metadata = {}
-                    elif not isinstance(safetensors_metadata, dict):
-                        raise TypeError("safetensors_metadata must be a dictionary.")
-                    else:
-                        logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                        new_safetensors_metadata = {}
-                        converted_keys = False
-                        for key, value in safetensors_metadata.items():
-                            if not isinstance(key, str) or not isinstance(value, str):
-                                converted_keys = True
-                                try:
-                                    new_key = str(key)
-                                    new_value = str(value)
-                                except Exception as e:
-                                    raise TypeError(
-                                        f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
-                                if new_key in new_safetensors_metadata:
-                                    logger.warning(
-                                        f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
-                                new_safetensors_metadata[new_key] = new_value
-                        safetensors_metadata = new_safetensors_metadata
-                        if converted_keys:
-                            logger.debug(
-                                f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")

-                    # Format is required to enable Accelerate to load the metadata
-                    # otherwise it raises an OSError
-                    safetensors_metadata["format"] = "pt"

-                    safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
-                else:
-                    torch.save(shard, join(save_dir, shard_file))
-                shard_size_mb = os.path.getsize(join(save_dir, shard_file)) / (1024 * 1024)
-                total_size_mb += shard_size_mb

-            if index is not None:
-                index_save_name = model_save_name + ".index.json"
-                index_save_path = join(save_dir, index_save_name)
-                # Save the index as well
-                with open(index_save_path, "w", encoding="utf-8") as f:
-                    content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-                    f.write(content)

-        total_size_gb = total_size_mb / 1024
-        size_diff_mb = pre_quantized_size_mb - total_size_mb
-        size_diff_gb = size_diff_mb / 1024
-        percent_diff = (size_diff_mb / pre_quantized_size_mb) * 100
-        logger.info(f"Pre-Quantized model size: {pre_quantized_size_mb:.2f}MB, {pre_quantized_size_gb:.2f}GB")
-        logger.info(f"Quantized model size: {total_size_mb:.2f}MB, {total_size_gb:.2f}GB")
-        logger.info(f"Size difference: {size_diff_mb:.2f}MB, {size_diff_gb:.2f}GB - {percent_diff:.2f}%")

-        config.quantization_config = quantize_config.to_dict()
-        config.save_pretrained(save_dir)

-        quantize_config.save_pretrained(save_dir)

-        # need to copy .py files for model/tokenizers not yet merged to HF transformers
-        if self.trust_remote_code:
-            copy_py_files(save_dir, model_id_or_path=self.model_name_or_path)

-    def get_model_with_quantize(self, quantize_config, model_name_or_path):
-        config = AutoConfig.from_pretrained(
-            model_name_or_path,
-            trust_remote_code=True,
-        )

-        def skip(*args, **kwargs):
-            pass

-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-        transformers.modeling_utils._init_weights = False
-        init_contexts = [no_init_weights()]
-        with ContextManagers(init_contexts):
-            model = self.model_loader.from_config(
-                config, torch_dtype=torch.float16
-            )

-        if self.dynamic_expert_index is not None:
-            num_experts = getattr(config, self.dynamic_expert_index)
-            self.layer_modules = get_moe_layer_modules(layer_modules=self.layer_modules,
-                                                       num_experts=num_experts)

-        layers = find_layers(model)
-        ignore_layers = [self.lm_head] + self.base_modules

-        for name in list(layers.keys()):
-            # allow loading of quantized lm_head
-            if quantize_config.lm_head and name == self.lm_head:
-                continue

-            if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all(
-                not name.endswith(ignore_layer) for sublist in self.layer_modules for ignore_layer in sublist
-            ):
-                # log non-lm-head quantizerd layers only
-                if name is not self.lm_head:
-                    logger.info(f"The layer {name} is not quantized.")
-                del layers[name]

-        make_quant(
-            model,
-            layers,
-            quantize_config.bits,
-            quantize_config.group_size,
-            backend=BACKEND.AUTO,
-            format=quantize_config.format,
-            desc_act=quantize_config.desc_act,
-            pack=True,
-        )
-        model.tie_weights()

-        accelerate.load_checkpoint_in_model(
-            model,
-            dtype=torch.float16,
-            # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
-            checkpoint=self.checkpoint_file_name,
-            # device_map=device_map,
-            # offload_state_dict=True,
-            # offload_buffers=True,
+        ModelWriter.save_quantized(
+            self,
+            save_dir=save_dir,
+            use_safetensors=use_safetensors,
+            max_shard_size=max_shard_size,
+            quantized=self.quantized,
+            model_name_or_path=self.model_name_or_path,
+            model=self.model,
+            load_quantized_model=self.load_quantized_model,
+            qlinear_kernel=self.qlinear_kernel,
+            trust_remote_code=self.trust_remote_code,
+            safetensors_metadata=safetensors_metadata,
+            quantize_config=self.quantize_config,
+            dynamic_expert_index=self.dynamic_expert_index,
+            base_modules=self.base_modules,
+            lm_head=self.lm_head,
+            layer_modules=self.layer_modules,
+            checkpoint_file_name=checkpoint_file_name
         )
-        torch.cuda.empty_cache()
-        return model
+

     def save_pretrained(
         self,
@@ -1040,7 +808,7 @@ def from_quantized(
         **kwargs,
     ):

-        model, quantize_config, qlinear_kernel, load_quantized_model, generate = ModelLoader.from_quantized(
+        model, quantize_config, qlinear_kernel, load_quantized_model, generate, checkpoint_file_name = ModelLoader.from_quantized(
             model_name_or_path=model_name_or_path,
             device_map=device_map,
             max_memory=max_memory,
@@ -1064,6 +832,9 @@ def from_quantized(
         if generate is not None:
             cls.generate = generate

+        if checkpoint_file_name is not None:
+            cls.checkpoint_file_name = checkpoint_file_name
+
         return cls(
             model,
             quantized=True,
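
The new gptqmodel/models/writer.py is the third changed file but is not shown in this excerpt. Judging only from the call site above, its entry point presumably receives the forwarded state along these lines; the sketch below is an inference from the keyword arguments, not the actual implementation:

    from typing import Optional

    class ModelWriter:
        # Hypothetical signature inferred from the base.py call site above; the real
        # body would hold the serialization logic removed from base.py in this commit
        # (metadata stamping, GPTQ v2->v1 conversion, sharding, safetensors/torch.save,
        # and config persistence).
        @classmethod
        def save_quantized(
            cls,
            base_model,                    # the model wrapper passed positionally as `self` in base.py
            *,
            save_dir: str,
            use_safetensors: bool,
            max_shard_size: Optional[str],
            quantized: bool,
            model_name_or_path: str,
            model,                         # the underlying torch module (self.model at the call site)
            load_quantized_model: bool,
            qlinear_kernel,
            trust_remote_code: bool,
            safetensors_metadata: Optional[dict],
            quantize_config,
            dynamic_expert_index,
            base_modules,
            lm_head,
            layer_modules,
            checkpoint_file_name: str = "",
        ):
            ...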

gptqmodel/models/loader.py

Lines changed: 3 additions & 1 deletion
@@ -251,7 +251,8 @@ def from_quantized(
                 quantize_config,
                 None, # qlinear_kernel
                 False, # load_quantized_model
-                cls.generate
+                cls.generate,
+                None # return None if is SGLANG or VLLM
             )

         if quantize_config.format == FORMAT.MARLIN:
@@ -526,4 +527,5 @@ def skip(*args, **kwargs):
             qlinear_kernel,
             True, # load_quantized_model
             None, # return None if not SGLANG or VLLM
+            model.checkpoint_file_name
         )
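
Together, these two loader.py hunks and the from_quantized changes in base.py thread checkpoint_file_name from loading through to saving, which is what the re-save path exercises. A minimal round-trip sketch, assuming the public from_quantized/save_quantized API and an illustrative local path:

    from gptqmodel import GPTQModel  # assumed public import

    # Loading populates checkpoint_file_name via ModelLoader.from_quantized();
    # re-saving forwards it through save_quantized() to ModelWriter.
    model = GPTQModel.from_quantized("opt-125m-gptq-4bit")  # illustrative path
    model.save_quantized("opt-125m-gptq-4bit-resaved", use_safetensors=True)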
