Skip to content

Commit ce26c0b

Browse files
RF-Tar-Railtwyapx
andauthored
fix(pb): upload_forwardref (#44)
* fix: eval_type * Update client.py --------- Co-authored-by: nullcat <[email protected]>
1 parent f61ac5d commit ce26c0b

File tree

4 files changed

+98
-18
lines changed

4 files changed

+98
-18
lines changed

lagrange/client/client.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
Union,
1010
overload,
1111
Literal,
12+
TYPE_CHECKING,
13+
cast,
1214
)
1315
from collections.abc import Coroutine
1416

@@ -491,31 +493,30 @@ async def set_grp_request(self, grp_id: int, grp_req_seq: int, ev_type: int, act
491493
raise AssertionError(rsp.ret_code, rsp.err_msg)
492494

493495
@overload
494-
async def get_user_info(self, uid: str) -> UserInfo: ...
495-
496-
@overload
497-
async def get_user_info(self, uid: list[str]) -> list[UserInfo]: ...
498-
499-
@overload
500-
async def get_user_info(self, uin: int) -> UserInfo:
496+
async def get_user_info(self, uid_or_uin: Union[str, int], /) -> UserInfo:
501497
...
502498

503499
@overload
504-
async def get_user_info(self, uin: list[int]) -> list[UserInfo]:
500+
async def get_user_info(self, uid_or_uin: Union[list[str], list[int]], /) -> list[UserInfo]:
505501
...
506502

507503
async def get_user_info(
508-
self,
509-
uid: Union[str, list[str]] = None,
510-
uin: Union[int, list[int]] = None,
504+
self,
505+
uid_or_uin: Union[str, int, list[str], list[int]],
506+
/
511507
) -> Union[UserInfo, list[UserInfo]]:
512-
userid = uid or uin
513-
assert userid, "empty uid or uin"
514-
if not isinstance(userid, list):
515-
userid = [userid]
508+
if isinstance(uid_or_uin, list):
509+
assert uid_or_uin, "empty uid or uin"
510+
userid = uid_or_uin
511+
else:
512+
userid = [uid_or_uin]
516513
if isinstance(userid[0], int):
514+
if TYPE_CHECKING:
515+
userid = cast(list[int], userid)
517516
req, sc = PBGetInfoFromUinReq(uin=userid).encode(), 2
518517
elif isinstance(userid[0], str):
518+
if TYPE_CHECKING:
519+
userid = cast(list[str], userid)
519520
req, sc = PBGetInfoFromUidReq(uid=userid).encode(), 8
520521
else:
521522
raise TypeError(userid[0])

lagrange/utils/binary/protobuf/coder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def __getitem__(self, item: int) -> "ProtoEncodable":
2727

2828
def into(self, field: Union[int, tuple[int, ...]], tp: type[TProtoEncodable]) -> TProtoEncodable:
2929
if isinstance(field, int):
30-
return cast(tp, self.proto[field])
30+
return self.proto[field] # type: ignore
3131
else:
3232
data = self.proto
3333
for f in field:
3434
data = data[f] # type: ignore
35-
return cast(tp, data)
35+
return data # type: ignore
3636

3737

3838
class ProtoBuilder(Builder):

lagrange/utils/binary/protobuf/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import typing
99

1010
from .coder import Proto, proto_decode, proto_encode
11+
from .util import eval_type
1112

1213
_ProtoTypes = Union[str, list, dict, bytes, int, float, bool, "ProtoStruct"]
1314

@@ -239,7 +240,7 @@ def update_forwardref(cls, mapping: dict[str, "type[ProtoStruct]"]):
239240
typ = field.type
240241
if isinstance(typ, str):
241242
typ = ForwardRef(typ, is_argument=False, is_class=True)
242-
field.type = typing._eval_type(typ, mapping, mapping) # type: ignore
243+
field.type = eval_type(typ, mapping, mapping) # type: ignore
243244
field._unevaluated = False
244245
except NameError:
245246
pass
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import collections.abc
2+
import functools
3+
import operator
4+
import sys
5+
import types
6+
import typing
7+
from typing import ForwardRef, List
8+
import typing_extensions
9+
from typing_extensions import ParamSpec
10+
11+
_GenericAlias = type(List[int])
12+
SpecialType = (_GenericAlias, types.GenericAlias)
13+
14+
if sys.version_info >= (3, 10):
15+
SpecialType += (types.UnionType,)
16+
17+
if sys.version_info >= (3, 11):
18+
eval_type = typing._eval_type # type: ignore
19+
else:
20+
def _is_param_expr(arg):
21+
return arg is ... or isinstance(arg, (tuple, list, ParamSpec, typing_extensions._ConcatenateGenericAlias)) # type: ignore
22+
23+
24+
def _should_unflatten_callable_args(typ, args):
25+
"""Internal helper for munging collections.abc.Callable's __args__.
26+
27+
The canonical representation for a Callable's __args__ flattens the
28+
argument types, see https://github.com/python/cpython/issues/86361.
29+
30+
For example::
31+
32+
>>> import collections.abc
33+
>>> P = ParamSpec('P')
34+
>>> collections.abc.Callable[[int, int], str].__args__ == (int, int, str)
35+
True
36+
>>> collections.abc.Callable[P, str].__args__ == (P, str)
37+
True
38+
39+
As a result, if we need to reconstruct the Callable from its __args__,
40+
we need to unflatten it.
41+
"""
42+
return (
43+
typ.__origin__ is collections.abc.Callable
44+
and not (len(args) == 2 and _is_param_expr(args[0]))
45+
)
46+
47+
48+
def eval_type(t, globalns, localns, recursive_guard=frozenset()):
49+
"""Evaluate all forward references in the given type t.
50+
51+
For use of globalns and localns see the docstring for get_type_hints().
52+
recursive_guard is used to prevent infinite recursion with a recursive
53+
ForwardRef.
54+
"""
55+
if isinstance(t, ForwardRef):
56+
return t._evaluate(globalns, localns, recursive_guard)
57+
if isinstance(t, SpecialType):
58+
if isinstance(t, types.GenericAlias):
59+
args = tuple(
60+
ForwardRef(arg) if isinstance(arg, str) else arg
61+
for arg in t.__args__
62+
)
63+
is_unpacked = getattr(t, "__unpacked__", False)
64+
if _should_unflatten_callable_args(t, args):
65+
t = t.__origin__[(args[:-1], args[-1])] # type: ignore
66+
else:
67+
t = t.__origin__[args] # type: ignore
68+
if is_unpacked:
69+
t = typing_extensions.Unpack[t]
70+
ev_args = tuple(eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) # type: ignore
71+
if ev_args == t.__args__: # type: ignore
72+
return t
73+
if isinstance(t, types.GenericAlias):
74+
return types.GenericAlias(t.__origin__, ev_args)
75+
if hasattr(types, "UnionType") and isinstance(t, types.UnionType): # type: ignore
76+
return functools.reduce(operator.or_, ev_args)
77+
return t.copy_with(ev_args) # type: ignore
78+
return t

0 commit comments

Comments
 (0)