Skip to content

Edge Case Fixes #127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
TYPE_CHECKING,
)

import asyncstdlib as a
from bt_decode import MetadataV15, PortableRegistry, decode as decode_by_type_string
from scalecodec.base import ScaleBytes, ScaleType, RuntimeConfigurationObject
from scalecodec.types import (
Expand Down Expand Up @@ -58,7 +57,7 @@
get_next_id,
rng as random,
)
from async_substrate_interface.utils.cache import async_sql_lru_cache
from async_substrate_interface.utils.cache import async_sql_lru_cache, CachedFetcher
from async_substrate_interface.utils.decoding import (
_determine_if_old_runtime_call,
_bt_decode_to_dict_or_list,
Expand Down Expand Up @@ -539,14 +538,14 @@ def __init__(
"You are instantiating the AsyncSubstrateInterface Websocket outside of an event loop. "
"Verify this is intended."
)
now = asyncio.new_event_loop().time()
now = 0.0
self.last_received = now
self.last_sent = now
self._in_use_ids = set()

async def __aenter__(self):
async with self._lock:
self._in_use += 1
await self.connect()
self._in_use += 1
await self.connect()
return self

@staticmethod
Expand All @@ -559,18 +558,19 @@ async def connect(self, force=False):
self.last_sent = now
if self._exit_task:
self._exit_task.cancel()
if not self._initialized or force:
self._initialized = True
try:
self._receiving_task.cancel()
await self._receiving_task
await self.ws.close()
except (AttributeError, asyncio.CancelledError):
pass
self.ws = await asyncio.wait_for(
connect(self.ws_url, **self._options), timeout=10
)
self._receiving_task = asyncio.create_task(self._start_receiving())
async with self._lock:
if not self._initialized or force:
try:
self._receiving_task.cancel()
await self._receiving_task
await self.ws.close()
except (AttributeError, asyncio.CancelledError):
pass
self.ws = await asyncio.wait_for(
connect(self.ws_url, **self._options), timeout=10
)
self._receiving_task = asyncio.create_task(self._start_receiving())
self._initialized = True

async def __aexit__(self, exc_type, exc_val, exc_tb):
async with self._lock: # TODO is this actually what I want to happen?
Expand Down Expand Up @@ -619,6 +619,7 @@ async def _recv(self) -> None:
self._open_subscriptions -= 1
if "id" in response:
self._received[response["id"]] = response
self._in_use_ids.remove(response["id"])
elif "params" in response:
self._received[response["params"]["subscription"]] = response
else:
Expand Down Expand Up @@ -649,6 +650,9 @@ async def send(self, payload: dict) -> int:
id: the internal ID of the request (incremented int)
"""
original_id = get_next_id()
while original_id in self._in_use_ids:
original_id = get_next_id()
self._in_use_ids.add(original_id)
# self._open_subscriptions += 1
await self.max_subscriptions.acquire()
try:
Expand All @@ -674,7 +678,7 @@ async def retrieve(self, item_id: int) -> Optional[dict]:
self.max_subscriptions.release()
return item
except KeyError:
await asyncio.sleep(0.001)
await asyncio.sleep(0.1)
return None


Expand Down Expand Up @@ -748,6 +752,12 @@ def __init__(
self.registry_type_map = {}
self.type_id_to_name = {}
self._mock = _mock
self._block_hash_fetcher = CachedFetcher(512, self._get_block_hash)
self._parent_hash_fetcher = CachedFetcher(512, self._get_parent_block_hash)
self._runtime_info_fetcher = CachedFetcher(16, self._get_block_runtime_info)
self._runtime_version_for_fetcher = CachedFetcher(
512, self._get_block_runtime_version_for
)

async def __aenter__(self):
if not self._mock:
Expand Down Expand Up @@ -1869,9 +1879,8 @@ async def get_metadata(self, block_hash=None) -> MetadataV15:

return runtime.metadata_v15

@a.lru_cache(maxsize=512)
async def get_parent_block_hash(self, block_hash):
return await self._get_parent_block_hash(block_hash)
return await self._parent_hash_fetcher.execute(block_hash)

async def _get_parent_block_hash(self, block_hash):
block_header = await self.rpc_request("chain_getHeader", [block_hash])
Expand Down Expand Up @@ -1916,9 +1925,8 @@ async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any:
"Unknown error occurred during retrieval of events"
)

@a.lru_cache(maxsize=16)
async def get_block_runtime_info(self, block_hash: str) -> dict:
return await self._get_block_runtime_info(block_hash)
return await self._runtime_info_fetcher.execute(block_hash)

get_block_runtime_version = get_block_runtime_info

Expand All @@ -1929,9 +1937,8 @@ async def _get_block_runtime_info(self, block_hash: str) -> dict:
response = await self.rpc_request("state_getRuntimeVersion", [block_hash])
return response.get("result")

@a.lru_cache(maxsize=512)
async def get_block_runtime_version_for(self, block_hash: str):
return await self._get_block_runtime_version_for(block_hash)
return await self._runtime_version_for_fetcher.execute(block_hash)

async def _get_block_runtime_version_for(self, block_hash: str):
"""
Expand Down Expand Up @@ -2149,14 +2156,14 @@ async def _make_rpc_request(
and current_time - self.ws.last_sent >= self.retry_timeout
):
if attempt >= self.max_retries:
logger.warning(
logger.error(
f"Timed out waiting for RPC requests {attempt} times. Exiting."
)
raise MaxRetriesExceeded("Max retries reached.")
else:
self.ws.last_received = time.time()
await self.ws.connect(force=True)
logger.error(
logger.warning(
f"Timed out waiting for RPC requests. "
f"Retrying attempt {attempt + 1} of {self.max_retries}"
)
Expand Down Expand Up @@ -2240,9 +2247,8 @@ async def rpc_request(
else:
raise SubstrateRequestException(result[payload_id][0])

@a.lru_cache(maxsize=512)
async def get_block_hash(self, block_id: int) -> str:
return await self._get_block_hash(block_id)
return await self._block_hash_fetcher.execute(block_id)

async def _get_block_hash(self, block_id: int) -> str:
return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"]
Expand Down
56 changes: 56 additions & 0 deletions async_substrate_interface/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import asyncio
from collections import OrderedDict
import functools
import os
import pickle
import sqlite3
from pathlib import Path
from typing import Callable, Any

import asyncstdlib as a


USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
CACHE_LOCATION = (
os.path.expanduser(
Expand Down Expand Up @@ -139,3 +144,54 @@ async def inner(self, *args, **kwargs):
return inner

return decorator


class LRUCache:
    """Simple bounded mapping with least-recently-used eviction.

    Both ``get`` and ``set`` refresh a key's recency; once more than
    ``max_size`` entries are held, the least recently used entry is evicted.
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        # OrderedDict insertion order doubles as recency order:
        # the least recently used entry sits at the front.
        self.cache = OrderedDict()

    def set(self, key, value):
        """Insert or overwrite *key*, evicting the LRU entry if over capacity."""
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)

    def get(self, key, default=None):
        """Return the value for *key*, marking it recently used, else *default*.

        ``default`` (new, backward-compatible) lets callers distinguish a
        cached ``None`` from a missing key.
        """
        if key in self.cache:
            # Mark as recently used
            self.cache.move_to_end(key)
            return self.cache[key]
        return default


class CachedFetcher:
    """De-duplicating, LRU-cached wrapper around a single-argument async method.

    Concurrent calls with the same argument share one in-flight request:
    the first caller performs the fetch while later callers await its
    future. Successful results are memoized in a bounded LRU cache;
    failures are not cached, so a later call retries the fetch.
    """

    def __init__(self, max_size: int, method: Callable):
        # Futures for requests currently in flight, keyed by the argument.
        self._inflight: dict[Any, asyncio.Future] = {}
        self._method = method
        self._cache = LRUCache(max_size=max_size)

    async def execute(self, single_arg: Any) -> Any:
        """Return the (possibly cached) result of ``method(single_arg)``.

        Raises whatever ``method`` raises; the exception is also delivered
        to any tasks awaiting the shared in-flight future.
        """
        # Membership test rather than a truthiness test on the value, so
        # falsy cached results (None, 0, "", {} ...) still count as hits.
        if single_arg in self._cache.cache:
            return self._cache.get(single_arg)

        if single_arg in self._inflight:
            # Another task is already fetching this key — share its result.
            return await self._inflight[single_arg]

        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self._inflight[single_arg] = future

        try:
            result = await self._method(single_arg)
            self._cache.set(single_arg, result)
            future.set_result(result)
            return result
        except Exception as e:
            # Propagate errors to waiters as well as the caller.
            future.set_exception(e)
            raise
        finally:
            # Always clear the in-flight slot so a failed fetch can be retried.
            self._inflight.pop(single_arg, None)
Loading