From d0be9b0ed80a8416ba4d0446ee1e4fa86b010943 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Mon, 14 Oct 2024 14:52:35 +0800 Subject: [PATCH 1/2] fix: eval_type --- lagrange/client/client.py | 45 ++++++++++---- lagrange/utils/binary/protobuf/coder.py | 4 +- lagrange/utils/binary/protobuf/models.py | 3 +- lagrange/utils/binary/protobuf/util.py | 78 ++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 14 deletions(-) create mode 100644 lagrange/utils/binary/protobuf/util.py diff --git a/lagrange/client/client.py b/lagrange/client/client.py index 3995177..4048c67 100644 --- a/lagrange/client/client.py +++ b/lagrange/client/client.py @@ -9,6 +9,8 @@ Union, overload, Literal, + TYPE_CHECKING, + cast, ) from collections.abc import Coroutine @@ -489,33 +491,54 @@ async def set_grp_request(self, grp_id: int, grp_req_seq: int, ev_type: int, act ) if rsp.ret_code: raise AssertionError(rsp.ret_code, rsp.err_msg) + # + # @overload + # async def get_user_info(self, uid: str) -> UserInfo: ... + # + # @overload + # async def get_user_info(self, uid: list[str]) -> list[UserInfo]: ... + # + # @overload + # async def get_user_info(self, uin: int) -> UserInfo: + # ... + # + # @overload + # async def get_user_info(self, uin: list[int]) -> list[UserInfo]: + # ... @overload - async def get_user_info(self, uid: str) -> UserInfo: ... + async def get_user_info(self, uid: str, /) -> UserInfo: + ... @overload - async def get_user_info(self, uid: list[str]) -> list[UserInfo]: ... + async def get_user_info(self, uids: list[str], /) -> list[UserInfo]: + ... @overload - async def get_user_info(self, uin: int) -> UserInfo: + async def get_user_info(self, uin: int, /) -> UserInfo: ... @overload - async def get_user_info(self, uin: list[int]) -> list[UserInfo]: + async def get_user_info(self, uins: list[int], /) -> list[UserInfo]: ... async def get_user_info( - self, - uid: Union[str, list[str]] = None, - uin: Union[int, list[int]] = None, + self, + uid_or_uin: Union[str, int, list[str], list[int]], + / ) -> Union[UserInfo, list[UserInfo]]: - userid = uid or uin - assert userid, "empty uid or uin" - if not isinstance(userid, list): - userid = [userid] + if isinstance(uid_or_uin, list): + assert uid_or_uin, "empty uid or uin" + userid = uid_or_uin + else: + userid = [uid_or_uin] if isinstance(userid[0], int): + if TYPE_CHECKING: + userid = cast(list[int], userid) req, sc = PBGetInfoFromUinReq(uin=userid).encode(), 2 elif isinstance(userid[0], str): + if TYPE_CHECKING: + userid = cast(list[str], userid) req, sc = PBGetInfoFromUidReq(uid=userid).encode(), 8 else: raise TypeError(userid[0]) diff --git a/lagrange/utils/binary/protobuf/coder.py b/lagrange/utils/binary/protobuf/coder.py index cb8d128..96d6032 100644 --- a/lagrange/utils/binary/protobuf/coder.py +++ b/lagrange/utils/binary/protobuf/coder.py @@ -27,12 +27,12 @@ def __getitem__(self, item: int) -> "ProtoEncodable": def into(self, field: Union[int, tuple[int, ...]], tp: type[TProtoEncodable]) -> TProtoEncodable: if isinstance(field, int): - return cast(tp, self.proto[field]) + return self.proto[field] # type: ignore else: data = self.proto for f in field: data = data[f] # type: ignore - return cast(tp, data) + return data # type: ignore class ProtoBuilder(Builder): diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 62d45ca..70d56e9 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -8,6 +8,7 @@ import typing from .coder import Proto, proto_decode, proto_encode +from .util import eval_type _ProtoTypes = Union[str, list, dict, bytes, int, float, bool, "ProtoStruct"] @@ -239,7 +240,7 @@ def update_forwardref(cls, mapping: dict[str, "type[ProtoStruct]"]): typ = field.type if isinstance(typ, str): typ = ForwardRef(typ, is_argument=False, is_class=True) - field.type = typing._eval_type(typ, mapping, mapping) # type: ignore + field.type = eval_type(typ, mapping, mapping) # type: ignore field._unevaluated = False except NameError: pass diff --git a/lagrange/utils/binary/protobuf/util.py b/lagrange/utils/binary/protobuf/util.py new file mode 100644 index 0000000..02b8cb5 --- /dev/null +++ b/lagrange/utils/binary/protobuf/util.py @@ -0,0 +1,78 @@ +import collections.abc +import functools +import operator +import sys +import types +import typing +from typing import ForwardRef, List +import typing_extensions +from typing_extensions import ParamSpec + +_GenericAlias = type(List[int]) +SpecialType = (_GenericAlias, types.GenericAlias) + +if sys.version_info >= (3, 10): + SpecialType += (types.UnionType,) + +if sys.version_info >= (3, 11): + eval_type = typing._eval_type # type: ignore +else: + def _is_param_expr(arg): + return arg is ... or isinstance(arg, (tuple, list, ParamSpec, typing_extensions._ConcatenateGenericAlias)) # type: ignore + + + def _should_unflatten_callable_args(typ, args): + """Internal helper for munging collections.abc.Callable's __args__. + + The canonical representation for a Callable's __args__ flattens the + argument types, see https://github.com/python/cpython/issues/86361. + + For example:: + + >>> import collections.abc + >>> P = ParamSpec('P') + >>> collections.abc.Callable[[int, int], str].__args__ == (int, int, str) + True + >>> collections.abc.Callable[P, str].__args__ == (P, str) + True + + As a result, if we need to reconstruct the Callable from its __args__, + we need to unflatten it. + """ + return ( + typ.__origin__ is collections.abc.Callable + and not (len(args) == 2 and _is_param_expr(args[0])) + ) + + + def eval_type(t, globalns, localns, recursive_guard=frozenset()): + """Evaluate all forward references in the given type t. + + For use of globalns and localns see the docstring for get_type_hints(). + recursive_guard is used to prevent infinite recursion with a recursive + ForwardRef. + """ + if isinstance(t, ForwardRef): + return t._evaluate(globalns, localns, recursive_guard) + if isinstance(t, SpecialType): + if isinstance(t, types.GenericAlias): + args = tuple( + ForwardRef(arg) if isinstance(arg, str) else arg + for arg in t.__args__ + ) + is_unpacked = getattr(t, "__unpacked__", False) + if _should_unflatten_callable_args(t, args): + t = t.__origin__[(args[:-1], args[-1])] # type: ignore + else: + t = t.__origin__[args] # type: ignore + if is_unpacked: + t = typing_extensions.Unpack[t] + ev_args = tuple(eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) # type: ignore + if ev_args == t.__args__: # type: ignore + return t + if isinstance(t, types.GenericAlias): + return types.GenericAlias(t.__origin__, ev_args) + if hasattr(types, "UnionType") and isinstance(t, types.UnionType): # type: ignore + return functools.reduce(operator.or_, ev_args) + return t.copy_with(ev_args) # type: ignore + return t From a9a408b476d3c957f35207a1454fb0db55b355e6 Mon Sep 17 00:00:00 2001 From: nullcat Date: Mon, 14 Oct 2024 18:54:39 +0800 Subject: [PATCH 2/2] Update client.py --- lagrange/client/client.py | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/lagrange/client/client.py b/lagrange/client/client.py index 4048c67..e978289 100644 --- a/lagrange/client/client.py +++ b/lagrange/client/client.py @@ -491,35 +491,13 @@ async def set_grp_request(self, grp_id: int, grp_req_seq: int, ev_type: int, act ) if rsp.ret_code: raise AssertionError(rsp.ret_code, rsp.err_msg) - # - # @overload - # async def get_user_info(self, uid: str) -> UserInfo: ... - # - # @overload - # async def get_user_info(self, uid: list[str]) -> list[UserInfo]: ... - # - # @overload - # async def get_user_info(self, uin: int) -> UserInfo: - # ... - # - # @overload - # async def get_user_info(self, uin: list[int]) -> list[UserInfo]: - # ... @overload - async def get_user_info(self, uid: str, /) -> UserInfo: + async def get_user_info(self, uid_or_uin: Union[str, int], /) -> UserInfo: ... @overload - async def get_user_info(self, uids: list[str], /) -> list[UserInfo]: - ... - - @overload - async def get_user_info(self, uin: int, /) -> UserInfo: - ... - - @overload - async def get_user_info(self, uins: list[int], /) -> list[UserInfo]: + async def get_user_info(self, uid_or_uin: Union[list[str], list[int]], /) -> list[UserInfo]: ... async def get_user_info(