1
+ import sys
1
2
from dataclasses import MISSING
2
3
from types import GenericAlias
3
- from typing import cast , TypeVar , Union , Any , Callable , overload , get_origin , get_args , get_type_hints
4
+ from typing import cast , TypeVar , Union , Any , Callable , overload , get_origin , get_args , ForwardRef
4
5
from collections .abc import Mapping
5
6
from typing_extensions import Self , TypeAlias , dataclass_transform
6
7
from typing import Optional , ClassVar
@@ -31,8 +32,8 @@ def __init__(self, tag: int, default: Any, default_factory: Any):
31
32
if tag <= 0 :
32
33
raise ValueError ("Tag must be a positive integer" )
33
34
self .tag = tag
34
- self ._default = default
35
- self ._default_factory = default_factory
35
+ self .default = default
36
+ self .default_factory = default_factory
36
37
self ._unevaluated = False
37
38
38
39
def ensure_annotation (self , name : str , type_ : Any ) -> None :
@@ -44,10 +45,10 @@ def ensure_annotation(self, name: str, type_: Any) -> None:
44
45
self ._unevaluated = True
45
46
46
47
def get_default (self ) -> Any :
47
- if self ._default is not MISSING :
48
- return self ._default
49
- elif self ._default_factory is not MISSING :
50
- return self ._default_factory ()
48
+ if self .default is not MISSING :
49
+ return self .default
50
+ elif self .default_factory is not MISSING :
51
+ return self .default_factory ()
51
52
return MISSING
52
53
53
54
@property
@@ -166,15 +167,16 @@ def check_type(value: Any, typ: Any) -> bool:
166
167
return False # or True if value is None else False
167
168
168
169
170
+ _unevaluated_classes : set [type ["ProtoStruct" ]] = set ()
171
+
172
+
169
173
@dataclass_transform (kw_only_default = True , field_specifiers = (proto_field ,))
170
174
class ProtoStruct :
171
175
__proto_fields__ : ClassVar [dict [str , ProtoField ]]
172
176
__proto_debug__ : ClassVar [bool ]
173
- __proto_evaluated__ : ClassVar [bool ] = False
174
177
175
178
def __init__ (self , __from_raw : bool = False , / , ** kwargs ):
176
179
undefined_params : list [ProtoField ] = []
177
- self ._evaluate ()
178
180
for name , field in self .__proto_fields__ .items ():
179
181
if name in kwargs :
180
182
value = kwargs .pop (name )
@@ -212,11 +214,13 @@ def _process_field(cls):
212
214
if field is MISSING :
213
215
raise TypeError (f'{ name !r} should define its proto_field!' )
214
216
field .ensure_annotation (name , typ )
217
+ if field ._unevaluated :
218
+ _unevaluated_classes .add (cls )
215
219
cls_fields .append (field )
216
220
217
221
for f in cls_fields :
218
222
fields [f .name ] = f
219
- if f ._default is MISSING :
223
+ if f .default is MISSING :
220
224
delattr (cls , f .name )
221
225
222
226
for name , value in cls .__dict__ .items ():
@@ -225,28 +229,20 @@ def _process_field(cls):
225
229
226
230
cls .__proto_fields__ = fields
227
231
228
- @classmethod
229
- def _evaluate (cls ):
230
- for base in reversed (cls .__mro__ ):
231
- if base in (ProtoStruct , object ):
232
- continue
233
- if getattr (base , '__proto_evaluated__' , False ):
234
- continue
235
- try :
236
- annotations = get_type_hints (base )
237
- except NameError :
238
- annotations = {}
239
- for field in base .__proto_fields__ .values ():
240
- if field ._unevaluated and field .name in annotations :
241
- field .type = annotations [field .name ]
242
- base .__proto_evaluated__ = True
243
-
244
232
@classmethod
245
233
def update_forwardref (cls , mapping : dict [str , "type[ProtoStruct]" ]):
246
234
"""更新 ForwardRef"""
247
235
for field in cls .__proto_fields__ .values ():
248
- if field ._unevaluated :
249
- field .type = typing ._eval_type (field .type , mapping , mapping ) # type: ignore
236
+ if not field ._unevaluated :
237
+ continue
238
+ try :
239
+ typ = field .type
240
+ if isinstance (typ , str ):
241
+ typ = ForwardRef (typ , is_argument = False , is_class = True )
242
+ field .type = typing ._eval_type (typ , mapping , mapping ) # type: ignore
243
+ field ._unevaluated = False
244
+ except NameError :
245
+ pass
250
246
251
247
def __init_subclass__ (cls , ** kwargs ):
252
248
cls .__proto_debug__ = kwargs .pop ("debug" ) if "debug" in kwargs else False
@@ -295,3 +291,21 @@ def decode(cls, data: bytes) -> Self:
295
291
if pb_dict and cls .__proto_debug__ : # unhandled tags
296
292
pass
297
293
return cls (True , ** kwargs )
294
+
295
+
296
+ def evaluate_all ():
297
+ modules = set ()
298
+ for cls in _unevaluated_classes :
299
+ modules .add (cls .__module__ )
300
+ for base in cls .__mro__ [- 1 :0 :- 1 ][2 :]:
301
+ modules .add (base .__module__ )
302
+ globalns = {}
303
+ for module in modules :
304
+ globalns .update (getattr (sys .modules .get (module , None ), "__dict__" , {}))
305
+ for cls in _unevaluated_classes :
306
+ cls .update_forwardref (globalns )
307
+ _unevaluated_classes .clear ()
308
+ modules .clear ()
309
+ globalns .clear ()
310
+ del modules
311
+ del globalns
0 commit comments