diff --git a/python/thirdweb-ai/src/thirdweb_ai/common/utils.py b/python/thirdweb-ai/src/thirdweb_ai/common/utils.py index 2b57b3e..6ce2b48 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/common/utils.py +++ b/python/thirdweb-ai/src/thirdweb_ai/common/utils.py @@ -77,3 +77,52 @@ def filter_response_keys(items: list[dict[str, Any]], keys_to_keep: list[str] | for key in keys_to_remove: item.pop(key, None) return items + + +# Aggregation function validation +VALID_EVENT_AGGREGATIONS = [ + "count()", + "countDistinct(address)", + "countDistinct(contract_address)", + "countDistinct(transaction_hash)", + "min(block_number)", + "max(block_number)", + "min(block_timestamp)", + "max(block_timestamp)", +] + +VALID_TRANSACTION_AGGREGATIONS = [ + "count()", + "countDistinct(from_address)", + "countDistinct(to_address)", + "countDistinct(contract_address)", + "sum(value)", + "avg(value)", + "min(value)", + "max(value)", + "min(block_number)", + "max(block_number)", + "min(block_timestamp)", + "max(block_timestamp)", +] + + +def validate_aggregation(agg: str, valid_aggregations: list[str]) -> str: + """Validate an aggregation function string.""" + # Handle aliases like "count() as event_count" + base_agg = agg.split(" as ")[0].strip() + + if base_agg not in valid_aggregations: + raise ValueError(f"Invalid aggregation function: {base_agg}. Valid options: {valid_aggregations}") + + return agg + + +def validate_event_aggregation(agg: str) -> str: + """Validate an event aggregation function.""" + return validate_aggregation(agg, VALID_EVENT_AGGREGATIONS) + + +def validate_transaction_aggregation(agg: str) -> str: + """Validate a transaction aggregation function.""" + return validate_aggregation(agg, VALID_TRANSACTION_AGGREGATIONS) diff --git a/python/thirdweb-ai/src/thirdweb_ai/services/insight.py b/python/thirdweb-ai/src/thirdweb_ai/services/insight.py index 94f6cba..bb5c18f 100644 --- a/python/thirdweb-ai/src/thirdweb_ai/services/insight.py +++ b/python/thirdweb-ai/src/thirdweb_ai/services/insight.py @@ -6,7 +6,16 @@ validate_signature, validate_transaction_hash, ) -from thirdweb_ai.common.utils import EVENT_KEYS_TO_KEEP, TRANSACTION_KEYS_TO_KEEP, clean_resolve, filter_response_keys +from thirdweb_ai.common.utils import ( + EVENT_KEYS_TO_KEEP, + TRANSACTION_KEYS_TO_KEEP, + VALID_EVENT_AGGREGATIONS, + VALID_TRANSACTION_AGGREGATIONS, + clean_resolve, + filter_response_keys, + validate_event_aggregation, + validate_transaction_aggregation, +) from thirdweb_ai.services.service import Service from thirdweb_ai.tools.tool import tool @@ -409,3 +418,163 @@ def decode_signature( signature = validate_signature(signature) out = self._get(f"resolve/{signature}", params) return clean_resolve(out) + + @tool( + description="Aggregate blockchain events with powerful grouping and aggregation options. Use this to get event counts, sums, or other aggregations grouped by fields like address, block, or time period." + ) + def aggregate_events( + self, + aggregate: Annotated[ + list[str], + f"Aggregation functions to apply. Valid options: {', '.join(VALID_EVENT_AGGREGATIONS)}. Can include aliases like 'count() as event_count'.", + ], + group_by: Annotated[ + Literal["address", "contract_address", "from_address", "block_number", "event_signature"] | None, + "Field to group events by for aggregation. Required when using aggregate functions.", + ] = None, + chain_id: Annotated[ + list[int] | int | None, + "Chain ID(s) to query (e.g., 1 for Ethereum Mainnet, 137 for Polygon). Specify multiple IDs as a list for cross-chain queries (max 5).", + ] = None, + contract_address: Annotated[ + str | None, + "Contract address to filter events by (e.g., '0x1234...'). Only return events emitted by this contract.", + ] = None, + transaction_hash: Annotated[ + str | None, + "Specific transaction hash to filter events by (e.g., '0xabc123...'). Useful for examining events in a particular transaction.", + ] = None, + event_signature: Annotated[ + str | None, + "Event signature to filter by (human-readable, e.g., 'Transfer(address,address,uint256)').", + ] = None, + limit: Annotated[ + int | None, + "Number of aggregated results to return (default 20, max 100).", + ] = 20, + page: Annotated[ + int | None, + "Page number for paginated results, starting from 0. 20 results are returned per page.", + ] = None, + ) -> dict[str, Any]: + # Validate aggregation functions + validated_aggregate = [validate_event_aggregation(agg) for agg in aggregate] + + params: dict[str, Any] = { + "sort_by": "block_number", + "sort_order": "desc", + "decode": True, + "aggregate": validated_aggregate, + } + + if group_by: + params["group_by"] = "address" if group_by == "contract_address" else group_by + + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + if contract_address: + params["filter_address"] = validate_address(contract_address) + if transaction_hash: + params["filter_transaction_hash"] = validate_transaction_hash(transaction_hash) + if event_signature: + params["filter_event_signature"] = event_signature + if limit: + params["limit"] = limit + if page: + params["page"] = page + + out = self._get("events", params) + + # Clean up response by removing chain_id from aggregations if present + if out.get("aggregations"): + for agg in out["aggregations"]: + if isinstance(agg, dict): + for value in agg.values(): + if isinstance(value, dict) and "chain_id" in value: + value.pop("chain_id") + + return out + + @tool( + description="Aggregate blockchain transactions with powerful grouping and aggregation options. Use this for transaction counts, volumes, sums, and other analytics grouped by address, block, or time period." + ) + def aggregate_transactions( + self, + aggregate: Annotated[ + list[str], + f"Aggregation functions to apply. Valid options: {', '.join(VALID_TRANSACTION_AGGREGATIONS)}. Can include aliases like 'count() as tx_count' or 'sum(value) as total_value'.", + ], + group_by: Annotated[ + Literal["from_address", "to_address", "block_number"] | None, + "Field to group transactions by for aggregation. Required when using aggregate functions.", + ] = None, + chain_id: Annotated[ + list[int] | int | None, + "Chain ID(s) to query (e.g., 1 for Ethereum, 137 for Polygon). Specify multiple IDs as a list for cross-chain queries.", + ] = None, + from_address: Annotated[ + str | None, + "Filter transactions sent from this address (e.g., '0x1234...'). Useful for tracking outgoing transactions from a wallet.", + ] = None, + to_address: Annotated[ + str | None, + "Filter transactions sent to this address (e.g., '0x1234...'). Useful for tracking incoming transactions to a contract or wallet.", + ] = None, + function_signature: Annotated[ + str | None, + "Function signature to filter by (e.g., 'approve(address,uint256)').", + ] = None, + value_above: Annotated[ + int | None, + "Filter for transactions with value above this amount (in wei - base blockchain units).", + ] = None, + limit: Annotated[ + int | None, + "Number of aggregated results to return (default 20, max 100).", + ] = 20, + page: Annotated[ + int | None, + "Page number for paginated results, starting from 0. 20 results are returned per page.", + ] = None, + ) -> dict[str, Any]: + # Validate aggregation functions + validated_aggregate = [validate_transaction_aggregation(agg) for agg in aggregate] + + params: dict[str, Any] = { + "sort_by": "block_number", + "sort_order": "desc", + "decode": True, + "aggregate": validated_aggregate, + } + + if group_by: + params["group_by"] = group_by + + chain_ids = chain_id if chain_id is not None else self.chain_ids + if chain_ids: + params["chain"] = chain_ids + if from_address: + params["filter_from_address"] = validate_address(from_address) + if to_address: + params["filter_to_address"] = validate_address(to_address) + if function_signature: + params["filter_function_signature"] = function_signature + if value_above: + params["value_gte"] = value_above + if limit: + params["limit"] = limit + if page: + params["page"] = page + + out = self._get("transactions", params) + + # Clean up response by removing chain_id from aggregations if present + if out.get("aggregations"): + for agg in out["aggregations"]: + if isinstance(agg, dict): + for value in agg.values(): + if isinstance(value, dict) and "chain_id" in value: + value.pop("chain_id") + + return out