diff --git a/gptqmodel/utils/__init__.py b/gptqmodel/utils/__init__.py
index c3eb254d9..0876e0190 100644
--- a/gptqmodel/utils/__init__.py
+++ b/gptqmodel/utils/__init__.py
@@ -1,2 +1,3 @@
 from .backend import BACKEND, get_backend
 from .perplexity import Perplexity
+from .vram import get_vram
diff --git a/gptqmodel/utils/vram.py b/gptqmodel/utils/vram.py
new file mode 100644
index 000000000..3bc09c07a
--- /dev/null
+++ b/gptqmodel/utils/vram.py
@@ -0,0 +1,184 @@
+from accelerate.utils import convert_bytes
+from typing import Dict, List, Tuple, Union, Optional
+from collections import defaultdict
+import torch
+import torch.nn as nn
+import re
+
+def dtype_byte_size(dtype: torch.dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+
+    Example:
+
+    ```py
+    >>> dtype_byte_size(torch.float32)
+    4
+    ```
+    """
+    if dtype == torch.bool:
+        return 1 / 8
+    elif dtype == "int2":
+        return 1 / 4
+    elif dtype == "int4":
+        return 1 / 2
+    elif dtype == "fp8":
+        return 1
+    elif dtype == torch.float8_e4m3fn:
+        return 1
+    elif dtype == torch.float16 or dtype == torch.bfloat16:
+        return 2
+    elif dtype == torch.float32 or dtype == torch.int32:
+        return 4
+    else:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+
+def _get_proper_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
+    """
+    Converts `dtype` to a `torch.dtype` if it is given as a string.
+    """
+    if isinstance(dtype, str):
+        # We accept "torch.float16" or just "float16"
+        dtype = dtype.replace("torch.", "")
+        dtype = getattr(torch, dtype)
+    return dtype
+
+def named_module_tensors(
+    module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
+):
+    """
+    A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
+    it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module we want the tensors on.
+        include_buffers (`bool`, *optional*, defaults to `True`):
+            Whether or not to include the buffers in the result.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to go look in every submodule or just return the direct parameters and buffers.
+        remove_non_persistent (`bool`, *optional*, defaults to `False`):
+            Whether or not to remove the non-persistent buffers from the result. Only used when `include_buffers` is
+            `True`.
+    """
+    yield from module.named_parameters(recurse=recurse)
+
+    if include_buffers:
+        non_persistent_buffers = set()
+        if remove_non_persistent:
+            non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
+        for named_buffer in module.named_buffers(recurse=recurse):
+            name, _ = named_buffer
+            if name not in non_persistent_buffers:
+                yield named_buffer
+
+def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
+    """
+    Gather the non-persistent buffers of a given module into a set.
+
+    Args:
+        module (`nn.Module`):
+            The module we want the non-persistent buffers on.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to go look in every submodule or just return the direct non-persistent buffers.
+ """ + + non_persistent_buffers_set = module._non_persistent_buffers_set + if recurse: + for _, m in module.named_modules(): + non_persistent_buffers_set |= m._non_persistent_buffers_set + + return non_persistent_buffers_set + + +def compute_module_sizes( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + buffers_only: bool = False, +): + """ + Compute the size of each submodule of a given model. + """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + if not buffers_only: + module_list = named_module_tensors(model, recurse=True) + else: + module_list = model.named_buffers(recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + +def get_all_layer_size( + modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str] +): + """ + from accelerate.utils get_max_layer_size + Utility function that will scan a list of named modules and return the maximum size used by one full layer. The + definition of a layer being: + - a module with no direct children (just parameters and buffers) + - a module whose class name is in the list `no_split_module_classes` + + Args: + modules (`List[Tuple[str, torch.nn.Module]]`): + The list of named modules where we want to determine the maximum layer size. + module_sizes (`Dict[str, int]`): + A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`). + no_split_module_classes (`List[str]`): + A list of class names for layers we don't want to be split. + + Returns: + `List[Tuple[str, str]]`: The size of all layer with the list of layer names and size str. 
+ """ + + layer_sizes = [] + modules_to_treat = modules.copy() + while len(modules_to_treat) > 0: + module_name, module = modules_to_treat.pop(0) + modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else [] + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: + size = module_sizes[module_name] + layer_sizes.append((module_name, convert_bytes(size))) + else: + modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat + + return layer_sizes + +def get_vram(model): + no_split_modules = getattr(model, "_no_split_modules", None) + if no_split_modules is None: + no_split_modules = [] + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) + ) + sizes = compute_module_sizes(model) + total_size = sizes[""] + + total_size = convert_bytes(total_size) + # List[Tuple[str, str]] + all_layers = get_all_layer_size(modules_to_treat, sizes, no_split_modules) + + return total_size, all_layers \ No newline at end of file diff --git a/tests/test_estimate_vram.py b/tests/test_estimate_vram.py new file mode 100644 index 000000000..f2f512222 --- /dev/null +++ b/tests/test_estimate_vram.py @@ -0,0 +1,26 @@ +# -- do not touch +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch + +import tempfile # noqa: E402 +import unittest # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.utils import get_vram + +class TestEstimateVram(unittest.TestCase): + NATIVE_MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" + + def test_estimate_vram(self): + model = GPTQModel.from_quantized( + self.NATIVE_MODEL_ID, + ) + + total_size, all_layers = get_vram(model) + print(f"{self.NATIVE_MODEL_ID} estimate vram : {total_size}") + for layer in all_layers: + layer_name, layer_size = layer + print(f"Layer {layer_name}: {layer_size}") + del model + assert total_size == "731.73 MB" \ No newline at end of file