diff --git a/lagrange/__init__.py b/lagrange/__init__.py index febc07b..47bf9da 100644 --- a/lagrange/__init__.py +++ b/lagrange/__init__.py @@ -9,6 +9,7 @@ from .utils.sign import sign_provider from .info import InfoManager from .info.app import app_list +from .utils.binary.protobuf.models import evaluate_all class Lagrange: @@ -66,3 +67,6 @@ def launch(self): log.root.info("Program exited by user") else: log.root.info("Program exited normally") + + +evaluate_all() diff --git a/lagrange/client/event.py b/lagrange/client/event.py index 64175f2..edf7ecc 100644 --- a/lagrange/client/event.py +++ b/lagrange/client/event.py @@ -32,7 +32,7 @@ async def _task_exec(self, client: "Client", event: "BaseEvent", handler: EVENT_ try: await handler(client, event) except Exception as e: - log.root.error( + log.root.exception( f"Unhandled exception on task {event}", exc_info=e ) diff --git a/lagrange/pb/highway/comm.py b/lagrange/pb/highway/comm.py index e928728..6ed5579 100644 --- a/lagrange/pb/highway/comm.py +++ b/lagrange/pb/highway/comm.py @@ -34,9 +34,9 @@ class AudioExtInfo(ProtoStruct): class ExtBizInfo(ProtoStruct): - pic: PicExtInfo = proto_field(1, default=PicExtInfo()) - video: VideoExtInfo = proto_field(2, default=VideoExtInfo()) - audio: AudioExtInfo = proto_field(3, default=AudioExtInfo()) + pic: PicExtInfo = proto_field(1, default_factory=PicExtInfo) + video: VideoExtInfo = proto_field(2, default_factory=VideoExtInfo) + audio: AudioExtInfo = proto_field(3, default_factory=AudioExtInfo) bus_type: Optional[int] = proto_field(4, default=None) diff --git a/lagrange/pb/highway/req.py b/lagrange/pb/highway/req.py index 5740913..7855d4f 100644 --- a/lagrange/pb/highway/req.py +++ b/lagrange/pb/highway/req.py @@ -63,13 +63,13 @@ class DownloadVideoExt(ProtoStruct): class DownloadExt(ProtoStruct): pic_ext: Optional[bytes] = proto_field(1, default=None) - video_ext: DownloadVideoExt = proto_field(2, default=DownloadVideoExt()) + video_ext: DownloadVideoExt = proto_field(2, default_factory=DownloadVideoExt) ptt_ext: Optional[bytes] = proto_field(3, default=None) class DownloadReq(ProtoStruct): node: IndexNode = proto_field(1) - ext: DownloadExt = proto_field(2, default=DownloadExt()) + ext: DownloadExt = proto_field(2, default_factory=DownloadExt) class NTV2RichMediaReq(ProtoStruct): diff --git a/lagrange/pb/highway/rsp.py b/lagrange/pb/highway/rsp.py index 07e79c2..11dff78 100644 --- a/lagrange/pb/highway/rsp.py +++ b/lagrange/pb/highway/rsp.py @@ -41,8 +41,8 @@ class DownloadInfo(ProtoStruct): domain: str = proto_field(1) url_path: str = proto_field(2) https_port: Optional[int] = proto_field(3, default=None) - v4_addrs: list[IPv4] = proto_field(4, default=[]) - v6_addrs: list[IPv6] = proto_field(5, default=[]) + v4_addrs: list[IPv4] = proto_field(4, default_factory=list) + v6_addrs: list[IPv6] = proto_field(5, default_factory=list) pic_info: Optional[PicUrlExtInfo] = proto_field(6, default=None) video_info: Optional[VideoExtInfo] = proto_field(7, default=None) diff --git a/lagrange/pb/message/rich_text/elems.py b/lagrange/pb/message/rich_text/elems.py index 10be308..43b9b3a 100644 --- a/lagrange/pb/message/rich_text/elems.py +++ b/lagrange/pb/message/rich_text/elems.py @@ -53,7 +53,7 @@ class NotOnlineImage(ProtoStruct): width: int = proto_field(9) res_id: str = proto_field(10) origin_path: Optional[str] = proto_field(15, default=None) - args: ImageReserveArgs = proto_field(34, default=ImageReserveArgs()) + args: ImageReserveArgs = proto_field(34, default_factory=ImageReserveArgs) class TransElem(ProtoStruct): @@ -90,7 +90,7 @@ class CustomFace(ProtoStruct): width: int = proto_field(22) height: int = proto_field(23) size: int = proto_field(25) - args: ImageReserveArgs = proto_field(34, default=ImageReserveArgs()) + args: ImageReserveArgs = proto_field(34, default_factory=ImageReserveArgs) class ExtraInfo(ProtoStruct): @@ -110,7 +110,7 @@ class SrcMsg(ProtoStruct): seq: int = proto_field(1) uin: int = proto_field(2, default=0) timestamp: int = proto_field(3) - elems: list[dict] = proto_field(5, default=[{}]) + elems: list[dict] = proto_field(5, default_factory=lambda: [{}]) pb_reserved: Optional[SrcMsgArgs] = proto_field(8, default=None) to_uin: int = proto_field(10, default=0) diff --git a/lagrange/pb/service/friend.py b/lagrange/pb/service/friend.py index 28e0642..8a198c5 100644 --- a/lagrange/pb/service/friend.py +++ b/lagrange/pb/service/friend.py @@ -45,7 +45,7 @@ class PBGetFriendListRequest(ProtoStruct): f7: int = proto_field(7, default=2147483647) # MaxValue body: list[GetFriendBody] = proto_field( 10001, - default=[ + default_factory=lambda: [ GetFriendBody(type=1, f2=GetFriendNumbers(f1=[103, 102, 20002, 27394])), GetFriendBody(type=4, f2=GetFriendNumbers(f1=[100, 101, 102])), ], diff --git a/lagrange/pb/service/group.py b/lagrange/pb/service/group.py index 4696834..3ac60ad 100644 --- a/lagrange/pb/service/group.py +++ b/lagrange/pb/service/group.py @@ -319,6 +319,12 @@ class MemberInfoLevel(ProtoStruct): num: int = proto_field(2) +class GetGrpMemberInfoRsp(ProtoStruct): + grp_id: int = proto_field(1) + body: "list[GetGrpMemberInfoRspBody]" = proto_field(2) + next_key: Optional[bytes] = proto_field(15, default=None) # base64(pb) + + class GetGrpMemberInfoRspBody(ProtoStruct): account: AccountInfo = proto_field(1) nickname: str = proto_field(10, default="") @@ -342,12 +348,6 @@ def is_owner(self) -> bool: return not self.is_admin and self.permission == 2 -class GetGrpMemberInfoRsp(ProtoStruct): - grp_id: int = proto_field(1) - body: list[GetGrpMemberInfoRspBody] = proto_field(2) - next_key: Optional[bytes] = proto_field(15, default=None) # base64(pb) - - class GetGrpListReqBody(ProtoStruct): cfg1: bytes = proto_field(1) cfg2: bytes = proto_field(2) @@ -396,7 +396,7 @@ class GrpInfo(ProtoStruct): class GetGrpListResponse(ProtoStruct): - grp_list: list[GrpInfo] = proto_field(2, default=[]) + grp_list: list[GrpInfo] = proto_field(2, default_factory=list) class PBGetInfoFromUidReq(ProtoStruct): @@ -425,8 +425,12 @@ def to_str(self) -> str: class GetInfoRspField(ProtoStruct, debug=True): - int_t: list[GetInfoRspF1] = proto_field(1, default=[]) - str_t: list[GetInfoRspF2] = proto_field(2, default=[]) + int_t: list[GetInfoRspF1] = proto_field(1, default_factory=list) + str_t: list[GetInfoRspF2] = proto_field(2, default_factory=list) + + +class GetInfoFromUidRsp(ProtoStruct): + body: list["GetInfoRspBody"] = proto_field(1) class GetInfoRspBody(ProtoStruct): @@ -434,10 +438,6 @@ class GetInfoRspBody(ProtoStruct): fields: GetInfoRspField = proto_field(2) -class GetInfoFromUidRsp(ProtoStruct): - body: list[GetInfoRspBody] = proto_field(1) - - class Oidb88D0Args(ProtoStruct): seq: Optional[int] = proto_field(22, default=None) diff --git a/lagrange/pb/status/group.py b/lagrange/pb/status/group.py index fb7d59c..0f17ad3 100644 --- a/lagrange/pb/status/group.py +++ b/lagrange/pb/status/group.py @@ -111,6 +111,13 @@ class GroupSub16Head(ProtoStruct): f44: Optional[PBGroupReaction] = proto_field(44, default=None) # set reaction only +class GroupSub20Head(ProtoStruct): + f1: int = proto_field(1) # 20 + grp_id: int = proto_field(4) + f13: int = proto_field(13) # 19 + body: "GroupSub20Body" = proto_field(26) + + class GroupSub20Body(ProtoStruct): type: int = proto_field(1) # 12: nudge, 14: group_sign # f2: int = proto_field(2) # 1061 @@ -121,13 +128,6 @@ class GroupSub20Body(ProtoStruct): f10: int = proto_field(10) # rand? -class GroupSub20Head(ProtoStruct): - f1: int = proto_field(1) # 20 - grp_id: int = proto_field(4) - f13: int = proto_field(13) # 19 - body: GroupSub20Body = proto_field(26) - - class PBGroupAlbumUpdateBody(ProtoStruct): # f1: 6 args: str = proto_field(2) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 3075a0e..6f7c0a4 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -1,10 +1,11 @@ -import inspect +import sys from dataclasses import MISSING from types import GenericAlias -from typing import cast, TypeVar, Union, Any, Callable, overload +from typing import cast, TypeVar, Union, Any, Callable, overload, get_origin, get_args, ForwardRef from collections.abc import Mapping from typing_extensions import Self, TypeAlias, dataclass_transform -from typing import Optional +from typing import Optional, ClassVar +import typing from .coder import Proto, proto_decode, proto_encode @@ -16,25 +17,46 @@ NoneType = type(None) +def _get_all_args(tp): + if args := get_args(tp): + for arg in args: + yield from _get_all_args(arg) + yield from args + + class ProtoField: + name: str + type: Any + def __init__(self, tag: int, default: Any, default_factory: Any): if tag <= 0: raise ValueError("Tag must be a positive integer") - self._tag = tag - self._default = default - self._default_factory = default_factory - - @property - def tag(self) -> int: - return self._tag + self.tag = tag + self.default = default + self.default_factory = default_factory + self._unevaluated = False + + def ensure_annotation(self, name: str, type_: Any) -> None: + self.name = name + self.type = type_ + if isinstance(type_, str): + self._unevaluated = True + elif (args := [*_get_all_args(type_)]) and any(isinstance(a, str) for a in args): + self._unevaluated = True def get_default(self) -> Any: - if self._default is not MISSING: - return self._default - elif self._default_factory is not MISSING: - return self._default_factory() + if self.default is not MISSING: + return self.default + elif self.default_factory is not MISSING: + return self.default_factory() return MISSING + @property + def type_without_optional(self) -> Any: + if get_origin(self.type) is Union: + return get_args(self.type)[0] + return self.type + @overload # `default` and `default_factory` are optional and mutually exclusive. def proto_field( @@ -87,155 +109,203 @@ def proto_field( return ProtoField(tag, default, default_factory) +def _decode(typ: type[_ProtoTypes], raw): + if isinstance(typ, str): + raise ValueError( + f"ForwardRef '{typ}' not resolved. " + f"Please call ProtoStruct.update_forwardref({{'{typ}': {typ}}}) before decoding" + ) + if issubclass(typ, ProtoStruct): + return typ.decode(raw) + elif typ is str: + return raw.decode(errors="ignore") + elif typ is dict: + return proto_decode(raw).proto + elif typ is bool: + return raw == 1 + elif typ is list: + if not isinstance(raw, list): + return [raw] + return raw + elif isinstance(typ, GenericAlias) and get_origin(typ) is list: + real_typ = get_args(typ)[0] + ret = [] + if isinstance(raw, list): + for v in raw: + ret.append(_decode(real_typ, v)) + else: + ret.append(_decode(real_typ, raw)) + return ret + elif isinstance(raw, typ): + return raw + else: + raise NotImplementedError(f"unknown type '{typ}' and data {raw}") + + +def check_type(value: Any, typ: Any) -> bool: + if isinstance(typ, str): + raise ValueError( + f"ForwardRef '{typ}' not resolved. " + f"Please call ProtoStruct.update_forwardref({{'{typ}': {typ}}}) before decoding" + ) + if typ is Any: + return True + if typ is list: + return isinstance(value, list) + if typ is dict: + return isinstance(value, dict) + if isinstance(typ, GenericAlias): + if get_origin(typ) is list: + return all(check_type(v, get_args(typ)[0]) for v in value) + if get_origin(typ) is dict: + return all(check_type(k, get_args(typ)[0]) and check_type(v, get_args(typ)[1]) for k, v in value.items()) + return False + if get_origin(typ) is Union: # Should Only be Optional + return check_type(value, get_args(typ)[0]) if value is not None else True + if isinstance(value, typ): + return True + return False # or True if value is None else False + + +_unevaluated_classes: set[type["ProtoStruct"]] = set() + + @dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,)) class ProtoStruct: - _anno_map: dict[str, tuple[type[_ProtoTypes], ProtoField]] - _proto_debug: bool - - def __init__(self, *args, **kwargs): - undefined_params: list[str] = [] - args = list(args) - for name, (typ, field) in self._anno_map.items(): - if args: - self._set_attr(name, typ, args.pop(0)) - elif name in kwargs: - self._set_attr(name, typ, kwargs.pop(name)) + __proto_fields__: ClassVar[dict[str, ProtoField]] + __proto_debug__: ClassVar[bool] + + def __init__(self, __from_raw: bool = False, /, **kwargs): + undefined_params: list[ProtoField] = [] + for name, field in self.__proto_fields__.items(): + if name in kwargs: + value = kwargs.pop(name) + if __from_raw: + value = _decode(field.type_without_optional, value) + if not check_type(value, field.type): + raise TypeError( + f"'{value}' is not a instance of type '{field.type}'" + ) + setattr(self, name, value) else: - if field.get_default() is not MISSING: - self._set_attr(name, typ, field.get_default()) + if (de := field.get_default()) is not MISSING: + setattr(self, name, de) else: - undefined_params.append(name) + undefined_params.append(field) if undefined_params: raise AttributeError( - f"Undefined parameters in '{self}': {undefined_params}" + f"Missing required parameters: {', '.join(f'{f.name}({f.tag})' for f in undefined_params)}" ) + @classmethod + def _process_field(cls): + fields = {} + + for b in cls.__mro__[-1:0:-1]: + base_fields = getattr(b, "__proto_fields__", None) + if base_fields is not None: + for f in base_fields.values(): + fields[f.name] = f + + cls_annotations = cls.__dict__.get('__annotations__', {}) + cls_fields: list[ProtoField] = [] + for name, typ in cls_annotations.items(): + field = getattr(cls, name, MISSING) + if field is MISSING: + raise TypeError(f'{name!r} should define its proto_field!') + field.ensure_annotation(name, typ) + if field._unevaluated: + _unevaluated_classes.add(cls) + cls_fields.append(field) + + for f in cls_fields: + fields[f.name] = f + if f.default is MISSING: + delattr(cls, f.name) + + for name, value in cls.__dict__.items(): + if isinstance(value, ProtoField) and not name in cls_annotations: + raise TypeError(f'{name!r} is a proto_field but has no type annotation') + + cls.__proto_fields__ = fields + + @classmethod + def update_forwardref(cls, mapping: dict[str, "type[ProtoStruct]"]): + """更新 ForwardRef""" + for field in cls.__proto_fields__.values(): + if not field._unevaluated: + continue + try: + typ = field.type + if isinstance(typ, str): + typ = ForwardRef(typ, is_argument=False, is_class=True) + field.type = typing._eval_type(typ, mapping, mapping) # type: ignore + field._unevaluated = False + except NameError: + pass + def __init_subclass__(cls, **kwargs): - cls._anno_map = cls._get_annotations() - cls._proto_debug = kwargs.pop("debug") if "debug" in kwargs else False + cls.__proto_debug__ = kwargs.pop("debug") if "debug" in kwargs else False + cls._process_field() super().__init_subclass__(**kwargs) def __repr__(self) -> str: attrs = "" - for k, v in self._get_stored_mapping().items(): - attrs += f"{k}={v}, " - return f"{self.__class__.__name__}({attrs[:-2]})" - - def _set_attr(self, name: str, data_typ: type[V], value: V) -> None: - # if get_origin(data_typ) is Union: - # data_typ = (typ for typ in get_args(data_typ) if typ is not NoneType) # type: ignore - if isinstance(data_typ, GenericAlias): # force ignore - pass - elif not isinstance(value, data_typ) and value is not None: - raise TypeError( - f"'{value}' is not a instance of type '{data_typ}'" - ) - setattr(self, name, value) - - @classmethod - def _get_annotations( - cls, - ) -> dict[str, tuple[type[_ProtoTypes], "ProtoField"]]: # Name: (ReturnType, ProtoField) - annotations: dict[str, tuple[type[_ProtoTypes], "ProtoField"]] = {} - for obj in reversed(inspect.getmro(cls)): - if obj in (ProtoStruct, object): # base object, ignore + for k, v in vars(self).items(): + if k.startswith("_"): continue - for name, typ in obj.__annotations__.items(): - if name[0] == "_": # ignore internal var - continue - if not hasattr(obj, name): - raise AttributeError(f"attribute ‘{name}' not defined") - field = getattr(obj, name) # type: ProtoField - - if not isinstance(field, ProtoField): - raise TypeError("attribute '{name}' is not a ProtoField object") - - if hasattr(typ, "__origin__"): - typ = typ.__origin__[typ.__args__[0]] - annotations[name] = (typ, field) - - return annotations - - @classmethod - def _get_field_mapping(cls) -> dict[int, tuple[str, type[_ProtoTypes]]]: # Tag, (Name, Type) - field_mapping: dict[int, tuple[str, type[_ProtoTypes]]] = {} - for name, (typ, field) in cls._anno_map.items(): - field_mapping[field.tag] = (name, typ) - return field_mapping - - def _get_stored_mapping(self) -> dict[str, NT]: - stored_mapping: dict[str, NT] = {} - for name, (_, _) in self._anno_map.items(): - stored_mapping[name] = getattr(self, name) - return stored_mapping - - def _encode(self, v: _ProtoTypes) -> NT: - if isinstance(v, ProtoStruct): - v = v.encode() - return v # type: ignore + attrs += f"{k}={v!r}, " + return f"{self.__class__.__name__}({attrs})" def encode(self) -> bytes: pb_dict: NT = {} - for name, (_, field) in self._anno_map.items(): + + def _encode(v: _ProtoTypes) -> NT: + if isinstance(v, ProtoStruct): + v = v.encode() + return v # type: ignore + + for name, field in self.__proto_fields__.items(): tag = field.tag if tag in pb_dict: raise ValueError(f"duplicate tag: {tag}") value: _ProtoTypes = getattr(self, name) if isinstance(value, list): - pb_dict[tag] = [self._encode(v) for v in value] + pb_dict[tag] = [_encode(v) for v in value] else: - pb_dict[tag] = self._encode(value) + pb_dict[tag] = _encode(value) return proto_encode(cast(Proto, pb_dict)) - @classmethod - def _decode(cls, typ: type[_ProtoTypes], value): - if issubclass(typ, ProtoStruct): - return typ.decode(value) - elif typ is str: - return value.decode(errors="ignore") - elif typ is dict: - return proto_decode(value).proto - elif typ is bool: - return value == 1 - elif typ is list: - if not isinstance(value, list): - return [value] - return value - elif isinstance(typ, GenericAlias): - if typ.__name__.lower() == "list": - real_typ = typ.__args__[0] - ret = [] - if isinstance(value, list): - for v in value: - ret.append(cls._decode(real_typ, v)) - else: - ret.append(cls._decode(real_typ, value)) - return ret - elif isinstance(value, typ): - return value - else: - raise NotImplementedError(f"unknown type '{typ}' and data {value}") - @classmethod def decode(cls, data: bytes) -> Self: if not data: return None # type: ignore pb_dict: Proto = proto_decode(data, 0).proto - mapping = cls._get_field_mapping() - - kwargs = {} - for tag, (name, typ) in mapping.items(): - if tag not in pb_dict: - _, field = cls._anno_map[name] - if field.get_default() is not MISSING: - kwargs[name] = field.get_default() - continue - - raise KeyError(f"tag {tag} not found in '{cls.__name__}'") - kwargs[name] = cls._decode(typ, pb_dict.pop(tag)) - if pb_dict and cls._proto_debug: # unhandled tags - pass - - return cls(**kwargs) + kwargs = { + field.name: pb_dict.pop(field.tag) + for field in cls.__proto_fields__.values() + if field.tag in pb_dict + } + if pb_dict and cls.__proto_debug__: # unhandled tags + pass + return cls(True, **kwargs) + + +def evaluate_all(): + modules = set() + for cls in _unevaluated_classes: + modules.add(cls.__module__) + for base in cls.__mro__[-1:0:-1][2:]: + modules.add(base.__module__) + globalns = {} + for module in modules: + globalns.update(getattr(sys.modules.get(module, None), "__dict__", {})) + for cls in _unevaluated_classes: + cls.update_forwardref(globalns) + _unevaluated_classes.clear() + modules.clear() + globalns.clear() + del modules + del globalns