Skip to content

Commit 081e137

Browse files
committed
schedulers/aws_batch: fix thread local sessions + raise error on missing memory resource
1 parent 90b05b0 commit 081e137

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import threading
3939
from dataclasses import dataclass
4040
from datetime import datetime
41-
from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple
41+
from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple, TypeVar, Callable
4242

4343
import torchx
4444
import yaml
@@ -86,8 +86,8 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
8686
reqs.append({"type": "VCPU", "value": str(cpu)})
8787

8888
mem = resource.memMB
89-
if mem <= 0:
90-
mem = 1000
89+
if mem <= 0:
    # Zero is not a positive value either: reject it together with negative
    # values instead of emitting a MEMORY requirement of "0".
    raise ValueError(
        f"AWSBatchScheduler requires memMB to be set to a positive value, got {mem}"
    )
9191
reqs.append({"type": "MEMORY", "value": str(mem)})
9292

9393
if resource.gpu > 0:
@@ -157,17 +157,27 @@ def __repr__(self) -> str:
157157
return str(self)
158158

159159

160-
def _thread_local_session() -> "boto3.session.Session":
161-
KEY = "torchx_aws_batch_session"
162-
local = threading.local()
163-
if hasattr(local, KEY):
164-
# pyre-ignore[16]
165-
return getattr(local, KEY)
160+
T = TypeVar('T')
161+
162+
def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
163+
local: threading.local = threading.local()
164+
key: str = "value"
165+
def wrapper() -> T:
166+
if key in local.__dict__:
167+
return local.__dict__[key]
168+
169+
v = f()
170+
local.__dict__[key] = v
171+
return v
172+
173+
return wrapper
174+
175+
176+
@_thread_local_cache
def _local_session() -> "boto3.session.Session":
    """Return a boto3 session; the decorator keeps one instance per thread."""
    # Function-scope import, as in the original implementation.
    from boto3.session import Session

    return Session()
171181

172182

173183
class AWSBatchScheduler(Scheduler, DockerWorkspace):
@@ -239,14 +249,14 @@ def __init__(
239249
def _client(self) -> Any:
240250
if self.__client:
241251
return self.__client
242-
return _thread_local_session().client("batch")
252+
return _local_session().client("batch")
243253

244254
@property
# pyre-fixme[3]: Return annotation cannot be `Any`.
def _log_client(self) -> Any:
    """Return the stored logs client if it is truthy, else a new "logs" client.

    The fallback client is built from the session returned by
    ``_local_session()``.
    """
    explicit = self.__log_client
    if not explicit:
        return _local_session().client("logs")
    return explicit
250260

251261
def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
252262
cfg = dryrun_info._cfg

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from contextlib import contextmanager
99
from typing import Generator
1010
from unittest.mock import MagicMock, patch
11+
import threading
1112

1213
import torchx
1314
from torchx import specs
1415
from torchx.schedulers.aws_batch_scheduler import (
1516
create_scheduler,
1617
AWSBatchScheduler,
1718
_role_to_node_properties,
19+
_local_session,
1820
)
1921

2022

@@ -192,6 +194,11 @@ def test_volume_mounts(self) -> None:
192194
mounts=[
193195
specs.VolumeMount(src="efsid", dst_path="/dst", read_only=True),
194196
],
197+
resource=specs.Resource(
198+
cpu=1,
199+
memMB=1000,
200+
gpu=0,
201+
),
195202
)
196203
props = _role_to_node_properties(0, role)
197204
self.assertEqual(
@@ -396,3 +403,16 @@ def test_log_iter(self) -> None:
396403
"foobar",
397404
],
398405
)
406+
407+
def test_local_session(self) -> None:
    """Check that ``_local_session`` is cached per thread, not globally."""
    a: object = _local_session()
    self.assertIs(a, _local_session())

    # The worker only collects the sessions it sees. An assertion that fails
    # inside the worker thread raises in that thread and is not propagated to
    # the thread running the test, so all checks are made here after join().
    seen: "list[tuple[object, object]]" = []

    def worker() -> None:
        seen.append((_local_session(), _local_session()))

    t = threading.Thread(target=worker)
    t.start()
    t.join()

    # Exactly one entry also proves that the worker ran to completion.
    self.assertEqual(len(seen), 1)
    b, b_again = seen[0]
    self.assertIs(b, b_again)
    self.assertIsNot(a, b)

0 commit comments

Comments
 (0)