from accelerate.utils import convert_bytes
from typing import Dict, List, Tuple, Union, Optional
from collections import defaultdict
import torch
import torch.nn as nn
import re

def dtype_byte_size(dtype: torch.dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    Example:

    ```py
    >>> dtype_byte_size(torch.float32)
    4
    ```
    """
    if dtype == torch.bool:
        return 1 / 8
    elif dtype == "int2":  # custom low-bit dtypes ("int2", "int4", "fp8") are matched by their string names
        return 1 / 4
    elif dtype == "int4":
        return 1 / 2
    elif dtype == "fp8":
        return 1
    elif dtype == torch.float8_e4m3fn:
        return 1
    elif dtype == torch.float16 or dtype == torch.bfloat16:
        return 2
    elif dtype == torch.float32 or dtype == torch.int32:
        return 4
    else:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")

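# Illustrative only (a sketch based on the branches above, where custom low-bit dtypes
# are passed as plain strings): the sub-byte cases return fractional sizes.
# >>> dtype_byte_size(torch.bool)
# 0.125
# >>> dtype_byte_size("int4")
# 0.5
# >>> dtype_byte_size(torch.bfloat16)
# 2
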
def _get_proper_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """
    Converts `dtype` to a `torch.dtype` if it was passed as a string.
    """
    if isinstance(dtype, str):
        # We accept "torch.float16" or just "float16"
        dtype = dtype.replace("torch.", "")
        dtype = getattr(torch, dtype)
    return dtype

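# Quick sanity check of the string handling above (illustrative, not exhaustive):
# >>> _get_proper_dtype("torch.float16")
# torch.float16
# >>> _get_proper_dtype("bfloat16")
# torch.bfloat16
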
def named_module_tensors(
    module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
):
    """
    A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
    it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.

    Args:
        module (`torch.nn.Module`):
            The module we want the tensors on.
        include_buffers (`bool`, *optional*, defaults to `True`):
            Whether or not to include the buffers in the result.
        recurse (`bool`, *optional*, defaults to `False`):
            Whether or not to go look in every submodule or just return the direct parameters and buffers.
        remove_non_persistent (`bool`, *optional*, defaults to `False`):
            Whether or not to remove the non-persistent buffers from the buffers. Only useful when
            `include_buffers=True`.
    """
    yield from module.named_parameters(recurse=recurse)

    if include_buffers:
        non_persistent_buffers = set()
        if remove_non_persistent:
            non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
        for named_buffer in module.named_buffers(recurse=recurse):
            name, _ = named_buffer
            if name not in non_persistent_buffers:
                yield named_buffer

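# Example (sketch): for a BatchNorm1d, the parameters come first, then the (persistent) buffers.
# >>> bn = nn.BatchNorm1d(2)
# >>> [name for name, _ in named_module_tensors(bn)]
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
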
def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
    """
    Gather all the non-persistent buffers of a given module into a set.

    Args:
        module (`nn.Module`):
            The module we want the non-persistent buffers on.
        recurse (`bool`, *optional*, defaults to `False`):
            Whether or not to go look in every submodule or just return the direct non-persistent buffers.
    """

    # Copy the set so the in-place union below does not mutate the module's own `_non_persistent_buffers_set`.
    non_persistent_buffers_set = set(module._non_persistent_buffers_set)
    if recurse:
        for _, m in module.named_modules():
            non_persistent_buffers_set |= m._non_persistent_buffers_set

    return non_persistent_buffers_set

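# Example (sketch): only buffers registered with `persistent=False` are returned.
# >>> m = nn.Module()
# >>> m.register_buffer("scale", torch.ones(2))                    # persistent (default)
# >>> m.register_buffer("mask", torch.ones(2), persistent=False)
# >>> get_non_persistent_buffers(m)
# {'mask'}
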
def compute_module_sizes(
    model: nn.Module,
    dtype: Optional[Union[str, torch.dtype]] = None,
    special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None,
    buffers_only: bool = False,
):
    """
    Compute the size of each submodule of a given model.
    """
    if dtype is not None:
        dtype = _get_proper_dtype(dtype)
        dtype_size = dtype_byte_size(dtype)
    if special_dtypes is not None:
        special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
        special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
    module_sizes = defaultdict(int)

    if not buffers_only:
        module_list = named_module_tensors(model, recurse=True)
    else:
        module_list = model.named_buffers(recurse=True)

    for name, tensor in module_list:
        if special_dtypes is not None and name in special_dtypes:
            size = tensor.numel() * special_dtypes_size[name]
        elif dtype is None:
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            # According to the code in set_module_tensor_to_device, these types won't be converted
            # so use their original size here
            size = tensor.numel() * dtype_byte_size(tensor.dtype)
        else:
            size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
        # Attribute the tensor's size to the module itself and to every parent module (including the root "").
        name_parts = name.split(".")
        for idx in range(len(name_parts) + 1):
            module_sizes[".".join(name_parts[:idx])] += size

    return module_sizes

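# Example (sketch): an nn.Linear(4, 4) holds 20 float32 parameters, i.e. 80 bytes in total.
# The returned dict contains the root total ("") as well as per-submodule and per-tensor totals:
# >>> model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# >>> sizes = compute_module_sizes(model)
# >>> sizes[""], sizes["0"], sizes["0.weight"], sizes["0.bias"]
# (80, 80, 64, 16)
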
def get_all_layer_size(
    modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
):
    """
    Adapted from `get_max_layer_size` in `accelerate.utils`.

    Utility function that will scan a list of named modules and return the size used by each full layer. The
    definition of a layer being:
    - a module with no direct children (just parameters and buffers)
    - a module whose class name is in the list `no_split_module_classes`

    Args:
        modules (`List[Tuple[str, torch.nn.Module]]`):
            The list of named modules where we want to determine the layer sizes.
        module_sizes (`Dict[str, int]`):
            A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.

    Returns:
        `List[Tuple[str, str]]`: The list of `(layer name, human-readable size)` tuples, one per layer.
    """

    layer_sizes = []
    modules_to_treat = modules.copy()
    while len(modules_to_treat) > 0:
        module_name, module = modules_to_treat.pop(0)
        modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []
        if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
            # Leaf module, or a module we don't want to split: record its size directly.
            size = module_sizes[module_name]
            layer_sizes.append((module_name, convert_bytes(size)))
        else:
            # Otherwise recurse into the children, keeping fully-qualified names.
            modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat

    return layer_sizes

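# Example (sketch): every leaf module becomes one "layer"; the exact size strings depend on
# how `accelerate.utils.convert_bytes` formats the numbers.
# >>> model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# >>> sizes = compute_module_sizes(model)
# >>> get_all_layer_size(list(model.named_children()), sizes, [])
# [('0', '80 bytes'), ('1', '0 bytes')]
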
def get_vram(model: nn.Module):
    """
    Return the total (human-readable) size of `model` together with the size of each of its layers.
    """
    no_split_modules = getattr(model, "_no_split_modules", None)
    if no_split_modules is None:
        no_split_modules = []
    modules_to_treat = (
        list(model.named_parameters(recurse=False))
        + list(model.named_children())
        + list(model.named_buffers(recurse=False))
    )
    sizes = compute_module_sizes(model)
    total_size = sizes[""]

    total_size = convert_bytes(total_size)
    # all_layers: List[Tuple[str, str]] of (layer name, human-readable size)
    all_layers = get_all_layer_size(modules_to_treat, sizes, no_split_modules)

    return total_size, all_layers
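
if __name__ == "__main__":
    # Minimal usage sketch: any small nn.Module works here; `demo_model` is purely
    # illustrative and not part of the utilities above.
    demo_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    demo_total, demo_layers = get_vram(demo_model)
    print(f"total size: {demo_total}")
    for layer_name, layer_size in demo_layers:
        print(f"{layer_name}: {layer_size}")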