 from transformers.modeling_utils import no_init_weights, shard_checkpoint
 from transformers.models.mllama.modeling_mllama import MllamaCrossAttentionDecoderLayer
 from transformers.utils.generic import ContextManagers
+from .writer import ModelWriter

 from ..nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear
 from ..nn_modules.qlinear.qlinear_qbits import QBitsQuantLinear, qbits_dtype
@@ -627,264 +628,31 @@ def save_quantized(
         max_shard_size: Optional[str] = None,
         model_base_name: Optional[str] = None
     ):
-        """save quantized model and configs to local disk"""
-        os.makedirs(save_dir, exist_ok=True)
-
-        pre_quantized_size_mb = get_model_files_size(self.model_name_or_path)
-        pre_quantized_size_gb = pre_quantized_size_mb / 1024
+        checkpoint_file_name = ""
+        if hasattr(self, "checkpoint_file_name") and self.checkpoint_file_name is not None:
+            checkpoint_file_name = self.checkpoint_file_name

-        # write gptqmodel tooling fingerprint to config
-        self.quantize_config.meta_set_versionable(
-            key=META_FIELD_QUANTIZER,
-            value=META_QUANTIZER_GPTQMODEL,
-            version=__version__,
-        )
-
-        self.quantize_config.meta_set(
-            key=META_FIELD_URI,
-            value=META_VALUE_URI,
-        )
-
-        self.quantize_config.meta_set(
-            key=META_FIELD_DAMP_PERCENT,
-            value=self.quantize_config.damp_percent
-        )
-
-        self.quantize_config.meta_set(
-            key=META_FIELD_DAMP_AUTO_INCREMENT,
-            value=self.quantize_config.damp_auto_increment
-        )
-
-        # The config, quantize_config and model may be edited in place in save_quantized.
-        config = copy.deepcopy(self.model.config)
-        quantize_config = copy.deepcopy(self.quantize_config)
-
-        if not self.quantized:
-            raise ValueError("Save aborted as model is not quantized. Please call `quantize()` first.")
-
-        if model_base_name is None:
-            model_base_name = (
-                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
-            )
-
-        if quantize_config.format == FORMAT.GPTQ_V2:
-            logger.warning(
-                f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPTQModel version >= {MIN_VERSION_WITH_V2}."
-            )
-
-        if not self.load_quantized_model:
-            model = self.model
-            # # internal is always gptq v2 but allow users to pass gptq (v1) via config
-            if quantize_config.format == FORMAT.GPTQ:
-                # Model qzeros may be edited in place.
-                model = convert_gptq_v2_to_v1_format(
-                    model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
-                )
-        else:
-            model = self.get_model_with_quantize(quantize_config, self.model_name_or_path)
-        model.to(CPU)
-        state_dict = model.state_dict()
-
-        model_base_name = "model"
-
-        if use_safetensors:
-            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
-            model_save_name = model_base_name + ".safetensors"
-        else:
-            model_save_name = model_base_name + ".pt"
-
-        if not self.qlinear_kernel.SUPPORTS_SHARDS and max_shard_size is not None:
-            logger.warning("Sharding is not supported for this quant. Disabling sharding.")
-            max_shard_size = None
-
-        if max_shard_size is None:
-            if use_safetensors:
-                if safetensors_metadata is None:
-                    safetensors_metadata = {}
-                elif not isinstance(safetensors_metadata, dict):
-                    raise TypeError("safetensors_metadata must be a dictionary.")
-                else:
-                    logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                    new_safetensors_metadata = {}
-                    converted_keys = False
-                    for key, value in safetensors_metadata.items():
-                        if not isinstance(key, str) or not isinstance(value, str):
-                            converted_keys = True
-                            try:
-                                new_key = str(key)
-                                new_value = str(value)
-                            except Exception as e:
-                                raise TypeError(
-                                    f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
-                                )
-                        if new_key in new_safetensors_metadata:
-                            logger.warning(
-                                f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
-                            )
-                        new_safetensors_metadata[new_key] = new_value
-                    safetensors_metadata = new_safetensors_metadata
-                    if converted_keys:
-                        logger.debug(
-                            f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
-                        )
-
-                # Format is required to enable Accelerate to load the metadata
-                # otherwise it raises an OSError
-                safetensors_metadata["format"] = "pt"
-                safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
-            else:
-                logger.warning(
-                    "We highly suggest saving quantized model using safetensors format for security reasons. Please set `use_safetensors=True` whenever possible.")
-                torch.save(model.state_dict(), join(save_dir, model_save_name))
-            total_size_mb = os.path.getsize(join(save_dir, model_save_name)) / (1024 * 1024)
-        else:
-            # Shard checkpoint
-            shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
-
-            # Clean the folder from a previous save
-            for filename in os.listdir(save_dir):
-                full_filename = join(save_dir, filename)
-
-                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
-                filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
-                reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
-
-                if (
-                    filename.startswith(model_base_name)
-                    and isfile(full_filename)
-                    and filename not in shards.keys()
-                    and reg.fullmatch(filename_no_suffix) is not None
-                ):
-                    os.remove(full_filename)
-
-            total_size_mb = 0
-            # Save the model
-            for shard_file, shard in shards.items():
-                if use_safetensors:
-                    if safetensors_metadata is None:
-                        safetensors_metadata = {}
-                    elif not isinstance(safetensors_metadata, dict):
-                        raise TypeError("safetensors_metadata must be a dictionary.")
-                    else:
-                        logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                        new_safetensors_metadata = {}
-                        converted_keys = False
-                        for key, value in safetensors_metadata.items():
-                            if not isinstance(key, str) or not isinstance(value, str):
-                                converted_keys = True
-                                try:
-                                    new_key = str(key)
-                                    new_value = str(value)
-                                except Exception as e:
-                                    raise TypeError(
-                                        f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
-                            if new_key in new_safetensors_metadata:
-                                logger.warning(
-                                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
-                            new_safetensors_metadata[new_key] = new_value
-                        safetensors_metadata = new_safetensors_metadata
-                        if converted_keys:
-                            logger.debug(
-                                f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
-
-                    # Format is required to enable Accelerate to load the metadata
-                    # otherwise it raises an OSError
-                    safetensors_metadata["format"] = "pt"
-
-                    safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
-                else:
-                    torch.save(shard, join(save_dir, shard_file))
-                shard_size_mb = os.path.getsize(join(save_dir, shard_file)) / (1024 * 1024)
-                total_size_mb += shard_size_mb
-
-            if index is not None:
-                index_save_name = model_save_name + ".index.json"
-                index_save_path = join(save_dir, index_save_name)
-                # Save the index as well
-                with open(index_save_path, "w", encoding="utf-8") as f:
-                    content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-                    f.write(content)
-
-        total_size_gb = total_size_mb / 1024
-        size_diff_mb = pre_quantized_size_mb - total_size_mb
-        size_diff_gb = size_diff_mb / 1024
-        percent_diff = (size_diff_mb / pre_quantized_size_mb) * 100
-        logger.info(f"Pre-Quantized model size: {pre_quantized_size_mb:.2f}MB, {pre_quantized_size_gb:.2f}GB")
-        logger.info(f"Quantized model size: {total_size_mb:.2f}MB, {total_size_gb:.2f}GB")
-        logger.info(f"Size difference: {size_diff_mb:.2f}MB, {size_diff_gb:.2f}GB - {percent_diff:.2f}%")
-
-        config.quantization_config = quantize_config.to_dict()
-        config.save_pretrained(save_dir)
-
-        quantize_config.save_pretrained(save_dir)
-
-        # need to copy .py files for model/tokenizers not yet merged to HF transformers
-        if self.trust_remote_code:
-            copy_py_files(save_dir, model_id_or_path=self.model_name_or_path)
-
-    def get_model_with_quantize(self, quantize_config, model_name_or_path):
-        config = AutoConfig.from_pretrained(
-            model_name_or_path,
-            trust_remote_code=True,
-        )
-
-        def skip(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-        transformers.modeling_utils._init_weights = False
-        init_contexts = [no_init_weights()]
-        with ContextManagers(init_contexts):
-            model = self.model_loader.from_config(
-                config, torch_dtype=torch.float16
-            )
-
-        if self.dynamic_expert_index is not None:
-            num_experts = getattr(config, self.dynamic_expert_index)
-            self.layer_modules = get_moe_layer_modules(layer_modules=self.layer_modules,
-                                                       num_experts=num_experts)
-
-        layers = find_layers(model)
-        ignore_layers = [self.lm_head] + self.base_modules
-
-        for name in list(layers.keys()):
-            # allow loading of quantized lm_head
-            if quantize_config.lm_head and name == self.lm_head:
-                continue
-
-            if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all(
-                not name.endswith(ignore_layer) for sublist in self.layer_modules for ignore_layer in sublist
-            ):
-                # log non-lm-head quantizerd layers only
-                if name is not self.lm_head:
-                    logger.info(f"The layer {name} is not quantized.")
-                del layers[name]
-
-        make_quant(
-            model,
-            layers,
-            quantize_config.bits,
-            quantize_config.group_size,
-            backend=BACKEND.AUTO,
-            format=quantize_config.format,
-            desc_act=quantize_config.desc_act,
-            pack=True,
-        )
-        model.tie_weights()
-
-        accelerate.load_checkpoint_in_model(
-            model,
-            dtype=torch.float16,
-            # This is very hacky but works due to https://github.com/huggingface/accelerate/blob/bd72a5f1a80d5146554458823f8aeda0a9db5297/src/accelerate/utils/modeling.py#L292
-            checkpoint=self.checkpoint_file_name,
-            # device_map=device_map,
-            # offload_state_dict=True,
-            # offload_buffers=True,
+        ModelWriter.save_quantized(
+            self,
+            save_dir=save_dir,
+            use_safetensors=use_safetensors,
+            max_shard_size=max_shard_size,
+            quantized=self.quantized,
+            model_name_or_path=self.model_name_or_path,
+            model=self.model,
+            load_quantized_model=self.load_quantized_model,
+            qlinear_kernel=self.qlinear_kernel,
+            trust_remote_code=self.trust_remote_code,
+            safetensors_metadata=safetensors_metadata,
+            quantize_config=self.quantize_config,
+            dynamic_expert_index=self.dynamic_expert_index,
+            base_modules=self.base_modules,
+            lm_head=self.lm_head,
+            layer_modules=self.layer_modules,
+            checkpoint_file_name=checkpoint_file_name
         )
-        torch.cuda.empty_cache()
-        return model
+

     def save_pretrained(
         self,
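
Note: the new `.writer` module itself is not part of this hunk; only the call site above is. As a rough sketch, inferred solely from the keyword arguments passed at that call site (names and structure here are assumptions, not the actual writer.py), the delegated entry point would need roughly this shape:

# Hypothetical sketch of ModelWriter.save_quantized, inferred from the call site above.
# The real writer.py is not shown in this commit; parameter names mirror the keyword arguments.
class ModelWriter:
    def save_quantized(
        self,                      # receives the calling GPTQ model wrapper (mixin-style call)
        save_dir,
        use_safetensors,
        max_shard_size,
        quantized,
        model_name_or_path,
        model,
        load_quantized_model,
        qlinear_kernel,
        trust_remote_code,
        safetensors_metadata,
        quantize_config,
        dynamic_expert_index,
        base_modules,
        lm_head,
        layer_modules,
        checkpoint_file_name="",
    ):
        # Expected to absorb the logic deleted above: quantizer metadata stamping,
        # GPTQ v2 -> v1 conversion, (sharded) safetensors export, and config serialization.
        ...
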
@@ -1040,7 +808,7 @@ def from_quantized(
         **kwargs,
     ):

-        model, quantize_config, qlinear_kernel, load_quantized_model, generate = ModelLoader.from_quantized(
+        model, quantize_config, qlinear_kernel, load_quantized_model, generate, checkpoint_file_name = ModelLoader.from_quantized(
             model_name_or_path=model_name_or_path,
             device_map=device_map,
             max_memory=max_memory,
@@ -1064,6 +832,9 @@ def from_quantized(
         if generate is not None:
             cls.generate = generate

+        if checkpoint_file_name is not None:
+            cls.checkpoint_file_name = checkpoint_file_name
+
         return cls(
             model,
             quantized=True,
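
For context, the public entry points are unchanged by this refactor: loading still goes through the wrapper's `from_quantized`, which now also receives `checkpoint_file_name` from `ModelLoader` and stashes it on the class, and saving delegates to `ModelWriter.save_quantized`. A minimal usage sketch (the model id and output path below are placeholders, not from this commit):

# Hedged usage sketch; model id and save path are placeholders.
from gptqmodel import GPTQModel

model = GPTQModel.from_quantized("org/some-model-gptq-4bit")   # loader also hands back checkpoint_file_name internally
model.save_quantized("./resaved-model", use_safetensors=True)  # now delegates to ModelWriter.save_quantized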