Skip to content

Commit 4591ebf

Browse files
committed
schedulers/aws_batch: fix thread local sessions + raise error on missing memory resource
1 parent 90b05b0 commit 4591ebf

File tree

3 files changed

+87
-19
lines changed

3 files changed

+87
-19
lines changed

torchx/components/utils.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def echo(
3939
entrypoint="echo",
4040
args=[msg],
4141
num_replicas=num_replicas,
42+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
4243
)
4344
],
4445
)
@@ -62,12 +63,21 @@ def touch(file: str, image: str = torchx.IMAGE) -> specs.AppDef:
6263
entrypoint="touch",
6364
args=[file],
6465
num_replicas=1,
66+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
6567
)
6668
],
6769
)
6870

6971

70-
def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.AppDef:
72+
def sh(
73+
*args: str,
74+
image: str = torchx.IMAGE,
75+
num_replicas: int = 1,
76+
cpu: int = 1,
77+
gpu: int = 0,
78+
memMB: int = 1024,
79+
h: Optional[str] = None,
80+
) -> specs.AppDef:
7181
"""
7282
Runs the provided command via sh. Currently sh does not support
7383
environment variable substitution.
@@ -76,7 +86,10 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap
7686
args: bash arguments
7787
image: image to use
7888
num_replicas: number of replicas to run
79-
89+
cpu: number of cpus per replica
90+
gpu: number of gpus per replica
91+
memMB: cpu memory in MB per replica
92+
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
8093
"""
8194

8295
escaped_args = " ".join(shlex.quote(arg) for arg in args)
@@ -90,6 +103,7 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap
90103
entrypoint="sh",
91104
args=["-c", escaped_args],
92105
num_replicas=num_replicas,
106+
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
93107
)
94108
],
95109
)
@@ -102,7 +116,7 @@ def python(
102116
script: Optional[str] = None,
103117
image: str = torchx.IMAGE,
104118
name: str = "torchx_utils_python",
105-
cpu: int = 2,
119+
cpu: int = 1,
106120
gpu: int = 0,
107121
memMB: int = 1024,
108122
h: Optional[str] = None,
@@ -164,8 +178,12 @@ def python(
164178
def binary(
165179
*args: str,
166180
entrypoint: str,
167-
name: str = "torchx_utils_python",
181+
name: str = "torchx_utils_binary",
168182
num_replicas: int = 1,
183+
cpu: int = 1,
184+
gpu: int = 0,
185+
memMB: int = 1024,
186+
h: Optional[str] = None,
169187
) -> specs.AppDef:
170188
"""
171189
Test component
@@ -174,6 +192,10 @@ def binary(
174192
args: arguments passed to the program in sys.argv[1:] (ignored with `--c`)
175193
name: name of the job
176194
num_replicas: number of copies to run (each on its own container)
195+
cpu: number of cpus per replica
196+
gpu: number of gpus per replica
197+
memMB: cpu memory in MB per replica
198+
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
177199
:return:
178200
"""
179201
return specs.AppDef(
@@ -184,8 +206,8 @@ def binary(
184206
image="<NONE>",
185207
entrypoint=entrypoint,
186208
num_replicas=num_replicas,
187-
resource=specs.Resource(cpu=2, gpu=0, memMB=4096),
188209
args=[*args],
210+
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
189211
)
190212
],
191213
)
@@ -219,6 +241,7 @@ def copy(src: str, dst: str, image: str = torchx.IMAGE) -> specs.AppDef:
219241
"--dst",
220242
dst,
221243
],
244+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
222245
),
223246
],
224247
)
@@ -261,6 +284,7 @@ def booth(
261284
"--tracker_base",
262285
tracker_base,
263286
],
287+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
264288
)
265289
],
266290
)

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,17 @@
3838
import threading
3939
from dataclasses import dataclass
4040
from datetime import datetime
41-
from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple
41+
from typing import (
42+
Dict,
43+
Iterable,
44+
Mapping,
45+
Optional,
46+
Any,
47+
TYPE_CHECKING,
48+
Tuple,
49+
TypeVar,
50+
Callable,
51+
)
4252

4353
import torchx
4454
import yaml
@@ -86,8 +96,10 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
8696
reqs.append({"type": "VCPU", "value": str(cpu)})
8797

8898
mem = resource.memMB
89-
if mem <= 0:
90-
mem = 1000
99+
if mem < 0:
100+
raise ValueError(
101+
f"AWSBatchScheduler requires memMB to be set to a positive value, got {mem}"
102+
)
91103
reqs.append({"type": "MEMORY", "value": str(mem)})
92104

93105
if resource.gpu > 0:
@@ -157,17 +169,29 @@ def __repr__(self) -> str:
157169
return str(self)
158170

159171

160-
def _thread_local_session() -> "boto3.session.Session":
161-
KEY = "torchx_aws_batch_session"
162-
local = threading.local()
163-
if hasattr(local, KEY):
164-
# pyre-ignore[16]
165-
return getattr(local, KEY)
172+
T = TypeVar("T")
173+
174+
175+
def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
176+
local: threading.local = threading.local()
177+
key: str = "value"
178+
179+
def wrapper() -> T:
180+
if key in local.__dict__:
181+
return local.__dict__[key]
182+
183+
v = f()
184+
local.__dict__[key] = v
185+
return v
186+
187+
return wrapper
188+
189+
190+
@_thread_local_cache
191+
def _local_session() -> "boto3.session.Session":
166192
import boto3.session
167193

168-
session = boto3.session.Session()
169-
setattr(local, KEY, session)
170-
return session
194+
return boto3.session.Session()
171195

172196

173197
class AWSBatchScheduler(Scheduler, DockerWorkspace):
@@ -239,14 +263,14 @@ def __init__(
239263
def _client(self) -> Any:
240264
if self.__client:
241265
return self.__client
242-
return _thread_local_session().client("batch")
266+
return _local_session().client("batch")
243267

244268
@property
245269
# pyre-fixme[3]: Return annotation cannot be `Any`.
246270
def _log_client(self) -> Any:
247271
if self.__log_client:
248272
return self.__log_client
249-
return _thread_local_session().client("logs")
273+
return _local_session().client("logs")
250274

251275
def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
252276
cfg = dryrun_info._cfg

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import threading
78
import unittest
89
from contextlib import contextmanager
910
from typing import Generator
@@ -15,6 +16,7 @@
1516
create_scheduler,
1617
AWSBatchScheduler,
1718
_role_to_node_properties,
19+
_local_session,
1820
)
1921

2022

@@ -192,6 +194,11 @@ def test_volume_mounts(self) -> None:
192194
mounts=[
193195
specs.VolumeMount(src="efsid", dst_path="/dst", read_only=True),
194196
],
197+
resource=specs.Resource(
198+
cpu=1,
199+
memMB=1000,
200+
gpu=0,
201+
),
195202
)
196203
props = _role_to_node_properties(0, role)
197204
self.assertEqual(
@@ -396,3 +403,16 @@ def test_log_iter(self) -> None:
396403
"foobar",
397404
],
398405
)
406+
407+
def test_local_session(self) -> None:
408+
a: object = _local_session()
409+
self.assertIs(a, _local_session())
410+
411+
def worker() -> None:
412+
b = _local_session()
413+
self.assertIs(b, _local_session())
414+
self.assertIsNot(a, b)
415+
416+
t = threading.Thread(target=worker)
417+
t.start()
418+
t.join()

0 commit comments

Comments
 (0)