auto infer model base name from model files #451

Merged · 11 commits · Oct 24, 2024

32 changes: 4 additions & 28 deletions gptqmodel/models/base.py
@@ -664,7 +664,6 @@ def save_quantized(

if model_base_name is None:
model_base_name = (
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
)

@@ -682,17 +681,11 @@ def save_quantized(
model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
)
else:
model = self.get_model_with_quantize(quantize_config)
model = self.get_model_with_quantize(quantize_config, self.model_name_or_path)
model.to(CPU)
state_dict = model.state_dict()

if quantize_config.model_file_base_name is None:
if use_safetensors:
model_base_name = "model"
else:
model_base_name = "pytorch_model"
else:
model_base_name = basename(quantize_config.model_file_base_name)
model_base_name = "model"

if use_safetensors:
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
@@ -822,17 +815,15 @@ def save_quantized(
config.quantization_config = quantize_config.to_dict()
config.save_pretrained(save_dir)

quantize_config.model_name_or_path = save_dir
quantize_config.model_file_base_name = model_base_name
quantize_config.save_pretrained(save_dir)

# need to copy .py files for model/tokenizers not yet merged to HF transformers
if self.trust_remote_code:
copy_py_files(save_dir, model_id_or_path=self.model_name_or_path)

def get_model_with_quantize(self, quantize_config):
def get_model_with_quantize(self, quantize_config, model_name_or_path):
config = AutoConfig.from_pretrained(
quantize_config.model_name_or_path,
model_name_or_path,
trust_remote_code=True,
)

@@ -1252,19 +1243,6 @@ def from_quantized(
if BITBLAS_AVAILABLE is False:
raise ValueError(BITBLAS_INSTALL_HINT)

if model_basename is None:
if quantize_config.model_file_base_name:
possible_model_basenames = [quantize_config.model_file_base_name]
else:
possible_model_basenames = [
f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g",
"model",
]
else:
possible_model_basenames = [model_basename]

quantize_config.model_name_or_path = model_name_or_path

extensions = []
if use_safetensors:
extensions.append(".safetensors")
@@ -1277,7 +1255,6 @@
is_sharded, resolved_archive_file, true_model_basename = get_checkpoints(
model_name_or_path=model_name_or_path,
extensions=extensions,
possible_model_basenames=possible_model_basenames,
**cached_file_kwargs,
)

@@ -1292,7 +1269,6 @@
"Loading of unsafe .bin files are not allowed by default. Pass allow_unsafe_loading=True to bypass."
)

quantize_config.model_file_base_name = true_model_basename
quantize_config.runtime_format = quantize_config.format

model_save_name = resolved_archive_file # In case a model is sharded, this would be `model.safetensors.index.json` which may later break.
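
Note: a minimal sketch of what the base.py change means for saved artifacts, using the GPTQModel calls shown in the tests below; the repo id is hypothetical and the import path is assumed. Saving now always uses the Hugging Face standard base name "model" rather than "gptq_model-{bits}bit-{group_size}g":

# Sketch, not library code: illustrates only the new on-disk naming.
from gptqmodel import GPTQModel  # assumed top-level export

model = GPTQModel.from_quantized("some-org/some-4bit-gptq-model")  # hypothetical repo id
model.save_pretrained("./out_dir")
# Expected files after this PR:
#   ./out_dir/model.safetensors                (single shard)
#   ./out_dir/model.safetensors.index.json     (sharded)
# Previously: ./out_dir/gptq_model-4bit-128g.safetensors
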
7 changes: 0 additions & 7 deletions gptqmodel/quantization/config.py
@@ -123,10 +123,6 @@ class QuantizeConfig():
# if OOM, can set to False
parallel_packing: bool = field(default=True)

# TODO: remove
model_name_or_path: Optional[str] = field(default=None)
model_file_base_name: Optional[str] = field(default=None)

# properties that do not directly contributes to quantization or quant inference should be placed in meta
# i.e. quantizer tool (producer) + version, timestamp, entity who made the quant, etc
meta: Optional[Dict] = field(default=None)
@@ -345,9 +341,6 @@ def to_dict(self):
"sym": self.sym,
"lm_head": self.lm_head,
"true_sequential": self.true_sequential,
# TODO: deprecate?
"model_name_or_path": self.model_name_or_path,
"model_file_base_name": self.model_file_base_name,
QUANT_METHOD_FIELD: self.quant_method,
FORMAT_FIELD_JSON: self.format,
META_FIELD: self.meta,
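
Note: a small sketch of the config-side effect, assuming QuantizeConfig is importable from the top-level gptqmodel package and accepts these field names as constructor arguments; the two removed path fields no longer appear in to_dict() or in the saved quantize_config.json:

# Sketch under the assumptions above.
from gptqmodel import QuantizeConfig  # assumed import location

cfg = QuantizeConfig(bits=4, group_size=128)
d = cfg.to_dict()
assert "model_name_or_path" not in d
assert "model_file_base_name" not in d  # mirrors the test_quant_formats.py change below
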
99 changes: 52 additions & 47 deletions gptqmodel/utils/model.py
@@ -16,7 +16,7 @@
import torch
import torch.nn as nn
import transformers
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub import HfApi, hf_hub_download, list_repo_files
from tqdm import tqdm
from transformers import AutoConfig, PretrainedConfig
from transformers.utils.hub import cached_file
@@ -573,7 +573,7 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon


def get_checkpoints(
model_name_or_path: str, extensions: List[str], possible_model_basenames: List[str], **cached_file_kwargs
model_name_or_path: str, extensions: List[str], **cached_file_kwargs
):
"""
Retrieves (and if necessary downloads from Hugging Face Hub) the model checkpoint. Sharding is supported. All the `possible_model_basenames` (e.g. `["model", "model-4bit-gptq"]`) will be explored over all `extensions` (e.g. `[".bin", ".safetensors"]`).
@@ -584,56 +584,61 @@

if os.path.isdir(model_name_or_path):
for ext in extensions:
for possible_model_basename in possible_model_basenames:
shard_index_name = possible_model_basename + ext + ".index.json"
searched_files.append(shard_index_name)
possible_index_file = os.path.join(model_name_or_path, shard_index_name)
if os.path.isfile(possible_index_file):
# The model is sharded over several checkpoints.
possible_model_basename = possible_index_file.replace(ext + ".index.json", "")
return True, possible_index_file, possible_model_basename
else:
model_save_name = os.path.join(model_name_or_path, possible_model_basename)
searched_files.append(possible_model_basename + ext)
if os.path.isfile(model_save_name + ext):
resolved_archive_file = model_save_name + ext
return False, resolved_archive_file, possible_model_basename
for fileName in os.listdir(model_name_or_path):
if ext in fileName:
shard_index_name = fileName + ".index.json"
searched_files.append(shard_index_name)
possible_index_file = os.path.join(model_name_or_path, shard_index_name)

if os.path.isfile(possible_index_file):
# The model is sharded over several checkpoints.
possible_model_basename = possible_index_file.replace(ext + ".index.json", "")
return True, possible_index_file, possible_model_basename
else:
model_save_name = os.path.join(model_name_or_path, fileName)
searched_files.append(fileName)
if os.path.isfile(model_save_name):
resolved_archive_file = model_save_name
return False, resolved_archive_file, fileName

else:
temp = None
for ext in extensions:
for possible_model_basename in possible_model_basenames:
shard_index_name = possible_model_basename + ext + ".index.json"
shard_index = cached_file(
model_name_or_path,
shard_index_name,
**cached_file_kwargs,
)
searched_files.append(shard_index_name)
if shard_index is not None:
# The model is sharded over several checkpoints.
with open(str(shard_index)) as f:
index_json = json.load(f)
# Download the shards from the index.json.
shards = list(set(index_json["weight_map"].values()))
for shard in shards:
resolved_archive_file = cached_file(
model_name_or_path,
shard,
**cached_file_kwargs,
)
return True, shard_index, possible_model_basename
else:
resolved_archive_file = cached_file(
files = list_repo_files(model_name_or_path)
for fileName in files:
for ext in extensions:
if ext in fileName:
shard_index_name = fileName + ".index.json"
shard_index = cached_file(
model_name_or_path,
possible_model_basename + ext,
shard_index_name,
**cached_file_kwargs,
)
if resolved_archive_file is None:
resolved_archive_file = temp
searched_files.append(possible_model_basename + ext)
if resolved_archive_file is not None:
temp = resolved_archive_file
return False, resolved_archive_file, possible_model_basename
searched_files.append(shard_index_name)
if shard_index is not None:
# The model is sharded over several checkpoints.
with open(str(shard_index)) as f:
index_json = json.load(f)
# Download the shards from the index.json.
shards = list(set(index_json["weight_map"].values()))
for shard in shards:
resolved_archive_file = cached_file(
model_name_or_path,
shard,
**cached_file_kwargs,
)
return True, shard_index, fileName
else:
resolved_archive_file = cached_file(
model_name_or_path,
fileName,
**cached_file_kwargs,
)
if resolved_archive_file is None:
resolved_archive_file = temp
searched_files.append(fileName)
if resolved_archive_file is not None:
temp = resolved_archive_file
return False, resolved_archive_file, fileName

if resolved_archive_file is None:
raise FileNotFoundError(
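
Note: the reworked get_checkpoints() discovers checkpoints from the files that actually exist (os.listdir() locally, list_repo_files() on the Hub) instead of probing a fixed possible_model_basenames list. A rough standalone sketch of the local-directory branch, not the library code itself:

import os

def find_checkpoint(model_dir, extensions=(".safetensors", ".bin")):
    # Scan real files and match on extension; prefer a sharded index if one exists.
    for ext in extensions:
        for name in os.listdir(model_dir):
            if ext in name:
                index_file = os.path.join(model_dir, name + ".index.json")
                if os.path.isfile(index_file):
                    return True, index_file, name                  # sharded checkpoint
                return False, os.path.join(model_dir, name), name  # single file
    raise FileNotFoundError(f"no checkpoint matching {extensions} under {model_dir}")
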
3 changes: 0 additions & 3 deletions tests/test_quant_formats.py
@@ -72,9 +72,6 @@ def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, forma

with open(tmpdirname + "/" + QUANT_CONFIG_FILENAME, "r") as f:
file_dict = json.loads(f.read())
# skip comparison of these two model path specific fields that do not exist in memory
file_dict["model_name_or_path"] = None
file_dict["model_file_base_name"] = None

# make sure the json dict saved to file matches config in memory
assert model.quantize_config.to_dict() == file_dict
2 changes: 1 addition & 1 deletion tests/test_serialization.py
@@ -22,7 +22,7 @@ def test_marlin_local_serialization(self):
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)

self.assertTrue(os.path.isfile(os.path.join(tmpdir, "gptq_model-4bit-128g.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "model.safetensors")))

model = GPTQModel.from_quantized(tmpdir, device="cuda:0", backend=BACKEND.MARLIN)

4 changes: 2 additions & 2 deletions tests/test_sharded.py
@@ -39,7 +39,7 @@ def test_save_and_load(self):

del model

index_file_path = os.path.join(tmp_dir, "gptq_model-4bit-128g.safetensors.index.json")
index_file_path = os.path.join(tmp_dir, "model.safetensors.index.json")
self.assertTrue(os.path.exists(index_file_path))

model = GPTQModel.from_quantized(
@@ -69,7 +69,7 @@ def test_save_and_load_no_shard(self):

del model

safetensors_file_path = os.path.join(tmp_dir, "gptq_model-4bit-128g.safetensors")
safetensors_file_path = os.path.join(tmp_dir, "model.safetensors")
self.assertTrue(os.path.exists(safetensors_file_path))

model = GPTQModel.from_quantized(