Commit 35ec18a

add estimate-vram (#452)
* add estimate-vram
* mod clean up
* remove unused
* update code
1 parent af8b371 commit 35ec18a

File tree

3 files changed: +211 -0 lines changed


gptqmodel/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 from .backend import BACKEND, get_backend
 from .perplexity import Perplexity
+from .vram import get_vram
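
With this re-export in place, the helper becomes importable straight from the utils package (a trivial sketch, mirroring the import used in the new test file below):

from gptqmodel.utils import get_vram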

gptqmodel/utils/vram.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
from accelerate.utils import convert_bytes
from typing import Dict, List, Tuple, Union, Optional
from collections import defaultdict
import torch
import torch.nn as nn
import re

def dtype_byte_size(dtype: torch.dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    Example:

    ```py
    >>> dtype_byte_size(torch.float32)
    4
    ```
    """
    if dtype == torch.bool:
        return 1 / 8
    elif dtype == "int2":
        return 1 / 4
    elif dtype == "int4":
        return 1 / 2
    elif dtype == "fp8":
        return 1
    elif dtype == torch.float8_e4m3fn:
        return 1
    elif dtype == torch.float16 or dtype == torch.bfloat16:
        return 2
    elif dtype == torch.float32 or dtype == torch.int32:
        return 4
    else:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")

def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
    """
    Just does torch.dtype(dtype) if necessary.
    """
    if isinstance(dtype, str):
        # We accept "torch.float16" or just "float16"
        dtype = dtype.replace("torch.", "")
        dtype = getattr(torch, dtype)
    return dtype

def named_module_tensors(
    module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
):
    """
    A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
    it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.

    Args:
        module (`torch.nn.Module`):
            The module we want the tensors on.
        include_buffers (`bool`, *optional*, defaults to `True`):
            Whether or not to include the buffers in the result.
        recurse (`bool`, *optional*, defaults to `False`):
            Whether or not to go look in every submodule or just return the direct parameters and buffers.
        remove_non_persistent (`bool`, *optional*, defaults to `False`):
            Whether or not to remove the non-persistent buffers from the result. Only useful when
            `include_buffers=True`.
    """
    yield from module.named_parameters(recurse=recurse)

    if include_buffers:
        non_persistent_buffers = set()
        if remove_non_persistent:
            non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
        for named_buffer in module.named_buffers(recurse=recurse):
            name, _ = named_buffer
            if name not in non_persistent_buffers:
                yield named_buffer

def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
    """
    Gather all non-persistent buffers of a given module into a set.

    Args:
        module (`nn.Module`):
            The module we want the non-persistent buffers on.
        recurse (`bool`, *optional*, defaults to `False`):
            Whether or not to go look in every submodule or just return the direct non-persistent buffers.
    """

    non_persistent_buffers_set = module._non_persistent_buffers_set
    if recurse:
        for _, m in module.named_modules():
            non_persistent_buffers_set |= m._non_persistent_buffers_set

    return non_persistent_buffers_set


def compute_module_sizes(
    model: nn.Module,
    dtype: Optional[Union[str, torch.device]] = None,
    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
    buffers_only: bool = False,
):
    """
    Compute the size of each submodule of a given model.
    """
    if dtype is not None:
        dtype = _get_proper_dtype(dtype)
        dtype_size = dtype_byte_size(dtype)
    if special_dtypes is not None:
        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
    module_sizes = defaultdict(int)

    if not buffers_only:
        module_list = named_module_tensors(model, recurse=True)
    else:
        module_list = model.named_buffers(recurse=True)

    for name, tensor in module_list:
        if special_dtypes is not None and name in special_dtypes:
            size = tensor.numel() * special_dtypes_size[name]
        elif dtype is None:
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            # According to the code in set_module_tensor_to_device, these types won't be converted
            # so use their original size here
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        else:
            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
        name_parts = name.split(".")
        for idx in range(len(name_parts) + 1):
            module_sizes[".".join(name_parts[:idx])] += size

    return module_sizes

def get_all_layer_size(
    modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
    """
    Adapted from `accelerate.utils.get_max_layer_size`.
    Utility function that will scan a list of named modules and return the size used by each full layer. The
    definition of a layer being:
    - a module with no direct children (just parameters and buffers)
    - a module whose class name is in the list `no_split_module_classes`

    Args:
        modules (`List[Tuple[str, torch.nn.Module]]`):
            The list of named modules whose layer sizes we want to determine.
        module_sizes (`Dict[str, int]`):
            A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.

    Returns:
        `List[Tuple[str, str]]`: A list of (layer name, human-readable size) pairs covering every layer.
    """

    layer_sizes = []
    modules_to_treat = modules.copy()
    while len(modules_to_treat) > 0:
        module_name, module = modules_to_treat.pop(0)
        modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []
        if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
            size = module_sizes[module_name]
            layer_sizes.append((module_name, convert_bytes(size)))
        else:
            modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat

    return layer_sizes

def get_vram(model):
    no_split_modules = getattr(model, "_no_split_modules", None)
    if no_split_modules is None:
        no_split_modules = []
    modules_to_treat = (
        list(model.named_parameters(recurse=False))
        + list(model.named_children())
        + list(model.named_buffers(recurse=False))
    )
    sizes = compute_module_sizes(model)
    total_size = sizes[""]

    total_size = convert_bytes(total_size)
    # List[Tuple[str, str]]
    all_layers = get_all_layer_size(modules_to_treat, sizes, no_split_modules)

    return total_size, all_layers
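
For orientation, a minimal usage sketch of the new helper (not part of the commit; the model id is simply the one used in the test below). `compute_module_sizes` adds each tensor's byte count to every prefix of its dotted name, so `sizes[""]` is the whole-model total, which `get_vram` returns in human-readable form alongside a per-layer breakdown:

from gptqmodel import GPTQModel
from gptqmodel.utils import get_vram

# Load a quantized model (any GPTQ checkpoint should work; this id is the one from the test).
model = GPTQModel.from_quantized("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit")

# total_size is a human-readable string (e.g. "731.73 MB");
# all_layers is a list of (layer name, size string) pairs.
total_size, all_layers = get_vram(model)
print(f"estimated VRAM: {total_size}")
for layer_name, layer_size in all_layers:
    print(f"  {layer_name}: {layer_size}")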

tests/test_estimate_vram.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
# -- do not touch
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch

import tempfile  # noqa: E402
import unittest  # noqa: E402
from gptqmodel import BACKEND, GPTQModel  # noqa: E402
from gptqmodel.utils import get_vram

class TestEstimateVram(unittest.TestCase):
    NATIVE_MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"

    def test_estimate_vram(self):
        model = GPTQModel.from_quantized(
            self.NATIVE_MODEL_ID,
        )

        total_size, all_layers = get_vram(model)
        print(f"{self.NATIVE_MODEL_ID} estimate vram : {total_size}")
        for layer in all_layers:
            layer_name, layer_size = layer
            print(f"Layer {layer_name}: {layer_size}")
        del model
        assert total_size == "731.73 MB"
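
To reproduce the asserted number locally, the test can presumably be run like any other test in the repo, for example (the command is an assumption, not part of the commit):

python -m pytest tests/test_estimate_vram.py -s

The `-s` flag simply lets the per-layer print statements appear in the output.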

0 commit comments
