Skip to content

Commit 533bd8f

Browse files
authored
Merge pull request #453 from nats-io/aiofiles
Use executor when writing/reading files from object store
2 parents b51f391 + 899d715 commit 533bd8f

File tree

7 files changed

+604
-421
lines changed

7 files changed

+604
-421
lines changed

Pipfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ pytest = "*"
1010
pytest-cov = "*"
1111
yapf = "*"
1212
toml = "*" # see https://github.com/google/yapf/issues/936
13+
exceptiongroup = "*"
1314

1415
[packages]
1516
nkeys = "*"
1617
aiohttp = "*"
18+
fast-mail-parser = "*"

Pipfile.lock

Lines changed: 455 additions & 372 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

nats/aio/client.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2016-2022 The NATS Authors
1+
# Copyright 2016-2023 The NATS Authors
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# you may not use this file except in compliance with the License.
44
# You may obtain a copy of the License at
@@ -21,6 +21,7 @@
2121
import logging
2222
import ssl
2323
import time
24+
import string
2425
from dataclasses import dataclass
2526
from email.parser import BytesParser
2627
from random import shuffle
@@ -1636,6 +1637,15 @@ async def _process_headers(self, headers) -> dict[str, str] | None:
16361637
hdr.update(parsed_hdr)
16371638
else:
16381639
hdr = parsed_hdr
1640+
1641+
if parse_email:
1642+
to_delete = []
1643+
for k in hdr.keys():
1644+
if any(c in k for c in string.whitespace):
1645+
to_delete.append(k)
1646+
for k in to_delete:
1647+
del hdr[k]
1648+
16391649
except Exception as e:
16401650
await self._error_cb(e)
16411651
return hdr

nats/js/object_store.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,14 @@
2525
import nats.errors
2626
from nats.js import api
2727
from nats.js.errors import (
28-
BadObjectMetaError, DigestMismatchError, InvalidObjectNameError,
29-
ObjectAlreadyExists, ObjectDeletedError, ObjectNotFoundError,
30-
NotFoundError, LinkIsABucketError
28+
BadObjectMetaError, DigestMismatchError, ObjectAlreadyExists,
29+
ObjectDeletedError, ObjectNotFoundError, NotFoundError, LinkIsABucketError
3130
)
3231
from nats.js.kv import MSG_ROLLUP_SUBJECT
3332

3433
VALID_BUCKET_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
3534
VALID_KEY_RE = re.compile(r"^[-/_=\.a-zA-Z0-9]+$")
3635

37-
38-
def key_valid(key: str) -> bool:
39-
if len(key) == 0 or key[0] == '.' or key[-1] == '.':
40-
return False
41-
return VALID_KEY_RE.match(key) is not None
42-
43-
4436
if TYPE_CHECKING:
4537
from nats.js import JetStreamContext
4638

@@ -136,10 +128,6 @@ def __init__(
136128
self._stream = stream
137129
self._js = js
138130

139-
def __sanitize_name(self, name: str) -> str:
140-
name = name.replace(".", "_")
141-
return name.replace(" ", "_")
142-
143131
async def get_info(
144132
self,
145133
name: str,
@@ -148,10 +136,7 @@ async def get_info(
148136
"""
149137
get_info will retrieve the current information for the object.
150138
"""
151-
obj = self.__sanitize_name(name)
152-
153-
if not key_valid(obj):
154-
raise InvalidObjectNameError
139+
obj = name
155140

156141
meta = OBJ_META_PRE_TEMPLATE.format(
157142
bucket=self._name,
@@ -185,15 +170,13 @@ async def get_info(
185170
async def get(
186171
self,
187172
name: str,
173+
writeinto: Optional[io.BufferedIOBase] = None,
188174
show_deleted: Optional[bool] = False,
189175
) -> ObjectResult:
190176
"""
191177
get will pull the object from the underlying stream.
192178
"""
193-
obj = self.__sanitize_name(name)
194-
195-
if not key_valid(obj):
196-
raise InvalidObjectNameError
179+
obj = name
197180

198181
# Grab meta info.
199182
info = await self.get_info(obj, show_deleted)
@@ -222,10 +205,22 @@ async def get(
222205

223206
h = sha256()
224207

208+
executor = None
209+
executor_fn = None
210+
if writeinto:
211+
executor = asyncio.get_running_loop().run_in_executor
212+
if hasattr(writeinto, 'buffer'):
213+
executor_fn = writeinto.buffer.write
214+
else:
215+
executor_fn = writeinto.write
216+
225217
async for msg in sub._message_iterator:
226218
tokens = msg._get_metadata_fields(msg.reply)
227219

228-
result.data += msg.data
220+
if executor:
221+
await executor(None, executor_fn, msg.data)
222+
else:
223+
result.data += msg.data
229224
h.update(msg.data)
230225

231226
# Check if we are done.
@@ -262,11 +257,7 @@ async def put(
262257
max_chunk_size=OBJ_DEFAULT_CHUNK_SIZE,
263258
)
264259

265-
obj = self.__sanitize_name(meta.name)
266-
267-
if not key_valid(obj):
268-
raise InvalidObjectNameError
269-
260+
obj = meta.name
270261
einfo = None
271262

272263
# Create the new nuid so chunks go on a new subject if the name is re-used.
@@ -285,14 +276,18 @@ async def put(
285276
pass
286277

287278
# Normalize based on type but treat all as readers.
288-
# FIXME: Need an async based reader as well.
279+
executor = None
289280
if isinstance(data, str):
290281
data = io.BytesIO(data.encode())
291282
elif isinstance(data, bytes):
292283
data = io.BytesIO(data)
293-
elif (not isinstance(data, io.BufferedIOBase)):
294-
# Only allowing buffered readers at the moment.
295-
raise TypeError("nats: subtype of io.BufferedIOBase was expected")
284+
elif hasattr(data, 'readinto') or isinstance(data, io.BufferedIOBase):
285+
# Need to delegate to a threaded executor to avoid blocking.
286+
executor = asyncio.get_running_loop().run_in_executor
287+
elif hasattr(data, 'buffer') or isinstance(data, io.TextIOWrapper):
288+
data = data.buffer
289+
else:
290+
raise TypeError("nats: invalid type for object store")
296291

297292
info = api.ObjectInfo(
298293
name=meta.name,
@@ -312,7 +307,12 @@ async def put(
312307

313308
while True:
314309
try:
315-
n = data.readinto(chunk)
310+
n = None
311+
if executor:
312+
n = await executor(None, data.readinto, chunk)
313+
else:
314+
n = data.readinto(chunk)
315+
316316
if n == 0:
317317
break
318318
payload = chunk[:n]
@@ -506,10 +506,7 @@ async def delete(self, name: str) -> ObjectResult:
506506
"""
507507
delete will delete the object.
508508
"""
509-
obj = self.__sanitize_name(name)
510-
511-
if not key_valid(obj):
512-
raise InvalidObjectNameError
509+
obj = name
513510

514511
# Grab meta info.
515512
info = await self.get_info(obj)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
name="nats-py",
77
extras_require={
88
'nkeys': ['nkeys'],
9+
'aiohttp': ['aiohttp'],
910
'fast_parse': ['fast-mail-parser'],
1011
}
1112
)

tests/test_js.py

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,6 @@ async def test_fetch_headers(self):
643643
assert msg.header['AAA-AAA-AAA'] == 'a'
644644
assert msg.header['AAA-BBB-AAA'] == ''
645645

646-
# FIXME: An unprocessable key makes the rest of the header be invalid.
647646
await js.publish(
648647
"test.nats.1",
649648
b'third_msg',
@@ -654,7 +653,7 @@ async def test_fetch_headers(self):
654653
}
655654
)
656655
msgs = await sub.fetch(1)
657-
assert msgs[0].header == None
656+
assert msgs[0].header == {'AAA-BBB-AAA': 'b'}
658657

659658
msg = await js.get_msg("test-nats", 4)
660659
assert msg.header == None
@@ -2875,17 +2874,21 @@ async def error_handler(e):
28752874
tmp.write(ls.encode())
28762875
tmp.close()
28772876

2878-
with pytest.raises(TypeError):
2879-
with open(tmp.name) as f:
2880-
info = await obs.put("tmp", f)
2881-
28822877
with open(tmp.name) as f:
28832878
info = await obs.put("tmp", f.buffer)
28842879
assert info.name == "tmp"
28852880
assert info.size == 1048609
28862881
assert info.chunks == 9
28872882
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
28882883

2884+
with open(tmp.name) as f:
2885+
info = await obs.put("tmp2", f)
2886+
assert info.name == "tmp2"
2887+
assert info.size == 1048609
2888+
assert info.chunks == 9
2889+
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
2890+
2891+
# By default this reads the complete data.
28892892
obr = await obs.get("tmp")
28902893
info = obr.info
28912894
assert info.name == "tmp"
@@ -2900,9 +2903,31 @@ async def error_handler(e):
29002903
assert info.chunks == 1
29012904

29022905
# Using a local file but not as a buffered reader.
2903-
with pytest.raises(TypeError):
2904-
with open("pyproject.toml") as f:
2905-
await obs.put("pyproject", f)
2906+
with open("pyproject.toml") as f:
2907+
info = await obs.put("pyproject2", f)
2908+
assert info.name == "pyproject2"
2909+
assert info.chunks == 1
2910+
2911+
# Write into file without getting complete data.
2912+
w = tempfile.NamedTemporaryFile(delete=False)
2913+
w.close()
2914+
with open(w.name, 'w') as f:
2915+
obr = await obs.get("tmp", writeinto=f)
2916+
assert obr.data == b''
2917+
assert obr.info.size == 1048609
2918+
assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
2919+
2920+
w2 = tempfile.NamedTemporaryFile(delete=False)
2921+
w2.close()
2922+
with open(w2.name, 'w') as f:
2923+
obr = await obs.get("tmp", writeinto=f.buffer)
2924+
assert obr.data == b''
2925+
assert obr.info.size == 1048609
2926+
assert obr.info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
2927+
2928+
with open(w2.name) as f:
2929+
result = f.read(-1)
2930+
assert len(result) == 1048609
29062931

29072932
await nc.close()
29082933

@@ -2947,7 +2972,7 @@ async def error_handler(e):
29472972
with pytest.raises(nats.js.errors.NotFoundError):
29482973
await obs.get("tmp")
29492974

2950-
with pytest.raises(nats.js.errors.InvalidObjectNameError):
2975+
with pytest.raises(nats.js.errors.NotFoundError):
29512976
await obs.get("")
29522977

29532978
res = await js.delete_object_store(bucket="sample")
@@ -3151,6 +3176,72 @@ async def error_handler(e):
31513176

31523177
await nc.close()
31533178

3179+
@async_test
3180+
async def test_object_aiofiles(self):
3181+
try:
3182+
import aiofiles
3183+
except ImportError:
3184+
pytest.skip("aiofiles not installed")
3185+
3186+
errors = []
3187+
3188+
async def error_handler(e):
3189+
print("Error:", e, type(e))
3190+
errors.append(e)
3191+
3192+
nc = await nats.connect(error_cb=error_handler)
3193+
js = nc.jetstream()
3194+
3195+
# Create a ~1MB object (1 MiB + 33 bytes).
3196+
obs = await js.create_object_store(bucket="big")
3197+
ls = ''.join("A" for _ in range(0, 1 * 1024 * 1024 + 33))
3198+
w = io.BytesIO(ls.encode())
3199+
info = await obs.put("big", w)
3200+
assert info.name == "big"
3201+
assert info.size == 1048609
3202+
assert info.chunks == 9
3203+
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
3204+
3205+
# Create actual file and put it in a bucket.
3206+
tmp = tempfile.NamedTemporaryFile(delete=False)
3207+
tmp.write(ls.encode())
3208+
tmp.close()
3209+
3210+
async with aiofiles.open(tmp.name) as f:
3211+
info = await obs.put("tmp", f)
3212+
assert info.name == "tmp"
3213+
assert info.size == 1048609
3214+
assert info.chunks == 9
3215+
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
3216+
3217+
async with aiofiles.open(tmp.name) as f:
3218+
info = await obs.put("tmp2", f)
3219+
assert info.name == "tmp2"
3220+
assert info.size == 1048609
3221+
assert info.chunks == 9
3222+
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
3223+
3224+
obr = await obs.get("tmp")
3225+
info = obr.info
3226+
assert info.name == "tmp"
3227+
assert info.size == 1048609
3228+
assert len(obr.data) == info.size # no reader reads whole file.
3229+
assert info.chunks == 9
3230+
assert info.digest == 'SHA-256=mhT1pLyi9JlIaqwVmvt0wQp2x09kor_80Lirl4SDblA='
3231+
3232+
# Using a local file.
3233+
async with aiofiles.open("pyproject.toml") as f:
3234+
info = await obs.put("pyproject", f.buffer)
3235+
assert info.name == "pyproject"
3236+
assert info.chunks == 1
3237+
3238+
async with aiofiles.open("pyproject.toml") as f:
3239+
info = await obs.put("pyproject2", f)
3240+
assert info.name == "pyproject2"
3241+
assert info.chunks == 1
3242+
3243+
await nc.close()
3244+
31543245

31553246
class ConsumerReplicasTest(SingleJetStreamServerTestCase):
31563247

tests/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ def setUp(self):
522522
server = NATSD(
523523
port=4555, config_file=get_config_file("conf/no_auth_user.conf")
524524
)
525-
server.debug = True
526525
self.server_pool.append(server)
527526
for natsd in self.server_pool:
528527
start_natsd(natsd)

0 commit comments

Comments
 (0)