Skip to content

Commit 98fb9cc

Browse files
d4l3k authored and facebook-github-bot committed
schedulers/aws_batch: fix thread local sessions + raise error on missing memory resource (#430)
Summary: This behavior was noticed in #429, and this change is intended to clean it up. Previously the thread-local logic was incorrect, so we would create a new session for every request, which caused a lot of log spam:

```
torchx 2022-03-21 14:19:45 INFO Found credentials in environment variables.
torchx 2022-03-21 14:19:45 INFO Found credentials in environment variables.
torchx 2022-03-21 14:19:45 INFO Found credentials in environment variables.
```

Pull Request resolved: #430

Test Plan: Updated unit tests

```
torchx run --scheduler aws_batch --wait --log dist.ddp --memMB 2000 -j 1x1 --script large-shm.py
```

Reviewed By: aivanou

Differential Revision: D35027238

Pulled By: d4l3k

fbshipit-source-id: f28024ac2b1ee789d389021ec0c8c668d5d8514d
1 parent 0a66255 commit 98fb9cc

File tree

5 files changed

+89
-19
lines changed

5 files changed

+89
-19
lines changed

torchx/components/serve.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def torchserve(
7070
entrypoint="python",
7171
args=args,
7272
port_map={"model-download": 8222},
73+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
7374
),
7475
],
7576
)

torchx/components/test/serve_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_torchserve(self) -> None:
3131
"1",
3232
],
3333
port_map={"model-download": 8222},
34+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
3435
),
3536
],
3637
)

torchx/components/utils.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def echo(
3939
entrypoint="echo",
4040
args=[msg],
4141
num_replicas=num_replicas,
42+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
4243
)
4344
],
4445
)
@@ -62,12 +63,21 @@ def touch(file: str, image: str = torchx.IMAGE) -> specs.AppDef:
6263
entrypoint="touch",
6364
args=[file],
6465
num_replicas=1,
66+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
6567
)
6668
],
6769
)
6870

6971

70-
def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.AppDef:
72+
def sh(
73+
*args: str,
74+
image: str = torchx.IMAGE,
75+
num_replicas: int = 1,
76+
cpu: int = 1,
77+
gpu: int = 0,
78+
memMB: int = 1024,
79+
h: Optional[str] = None,
80+
) -> specs.AppDef:
7181
"""
7282
Runs the provided command via sh. Currently sh does not support
7383
environment variable substitution.
@@ -76,7 +86,10 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap
7686
args: bash arguments
7787
image: image to use
7888
num_replicas: number of replicas to run
79-
89+
cpu: number of cpus per replica
90+
gpu: number of gpus per replica
91+
memMB: cpu memory in MB per replica
92+
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
8093
"""
8194

8295
escaped_args = " ".join(shlex.quote(arg) for arg in args)
@@ -90,6 +103,7 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap
90103
entrypoint="sh",
91104
args=["-c", escaped_args],
92105
num_replicas=num_replicas,
106+
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
93107
)
94108
],
95109
)
@@ -102,7 +116,7 @@ def python(
102116
script: Optional[str] = None,
103117
image: str = torchx.IMAGE,
104118
name: str = "torchx_utils_python",
105-
cpu: int = 2,
119+
cpu: int = 1,
106120
gpu: int = 0,
107121
memMB: int = 1024,
108122
h: Optional[str] = None,
@@ -164,8 +178,12 @@ def python(
164178
def binary(
165179
*args: str,
166180
entrypoint: str,
167-
name: str = "torchx_utils_python",
181+
name: str = "torchx_utils_binary",
168182
num_replicas: int = 1,
183+
cpu: int = 1,
184+
gpu: int = 0,
185+
memMB: int = 1024,
186+
h: Optional[str] = None,
169187
) -> specs.AppDef:
170188
"""
171189
Test component
@@ -174,6 +192,10 @@ def binary(
174192
args: arguments passed to the program in sys.argv[1:] (ignored with `--c`)
175193
name: name of the job
176194
num_replicas: number of copies to run (each on its own container)
195+
cpu: number of cpus per replica
196+
gpu: number of gpus per replica
197+
memMB: cpu memory in MB per replica
198+
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
177199
:return:
178200
"""
179201
return specs.AppDef(
@@ -184,8 +206,8 @@ def binary(
184206
image="<NONE>",
185207
entrypoint=entrypoint,
186208
num_replicas=num_replicas,
187-
resource=specs.Resource(cpu=2, gpu=0, memMB=4096),
188209
args=[*args],
210+
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
189211
)
190212
],
191213
)
@@ -219,6 +241,7 @@ def copy(src: str, dst: str, image: str = torchx.IMAGE) -> specs.AppDef:
219241
"--dst",
220242
dst,
221243
],
244+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
222245
),
223246
],
224247
)
@@ -261,6 +284,7 @@ def booth(
261284
"--tracker_base",
262285
tracker_base,
263286
],
287+
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
264288
)
265289
],
266290
)

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,17 @@
3838
import threading
3939
from dataclasses import dataclass
4040
from datetime import datetime
41-
from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple
41+
from typing import (
42+
Dict,
43+
Iterable,
44+
Mapping,
45+
Optional,
46+
Any,
47+
TYPE_CHECKING,
48+
Tuple,
49+
TypeVar,
50+
Callable,
51+
)
4252

4353
import torchx
4454
import yaml
@@ -86,8 +96,10 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
8696
reqs.append({"type": "VCPU", "value": str(cpu)})
8797

8898
memMB = resource.memMB
89-
if memMB <= 0:
90-
memMB = 1000
99+
if memMB < 0:
100+
raise ValueError(
101+
f"AWSBatchScheduler requires memMB to be set to a positive value, got {memMB}"
102+
)
91103
reqs.append({"type": "MEMORY", "value": str(memMB)})
92104

93105
if resource.gpu > 0:
@@ -162,17 +174,29 @@ def __repr__(self) -> str:
162174
return str(self)
163175

164176

165-
def _thread_local_session() -> "boto3.session.Session":
166-
KEY = "torchx_aws_batch_session"
167-
local = threading.local()
168-
if hasattr(local, KEY):
169-
# pyre-ignore[16]
170-
return getattr(local, KEY)
177+
T = TypeVar("T")
178+
179+
180+
def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
181+
local: threading.local = threading.local()
182+
key: str = "value"
183+
184+
def wrapper() -> T:
185+
if key in local.__dict__:
186+
return local.__dict__[key]
187+
188+
v = f()
189+
local.__dict__[key] = v
190+
return v
191+
192+
return wrapper
193+
194+
195+
@_thread_local_cache
196+
def _local_session() -> "boto3.session.Session":
171197
import boto3.session
172198

173-
session = boto3.session.Session()
174-
setattr(local, KEY, session)
175-
return session
199+
return boto3.session.Session()
176200

177201

178202
class AWSBatchScheduler(Scheduler, DockerWorkspace):
@@ -244,14 +268,14 @@ def __init__(
244268
def _client(self) -> Any:
245269
if self.__client:
246270
return self.__client
247-
return _thread_local_session().client("batch")
271+
return _local_session().client("batch")
248272

249273
@property
250274
# pyre-fixme[3]: Return annotation cannot be `Any`.
251275
def _log_client(self) -> Any:
252276
if self.__log_client:
253277
return self.__log_client
254-
return _thread_local_session().client("logs")
278+
return _local_session().client("logs")
255279

256280
def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
257281
cfg = dryrun_info._cfg

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import threading
78
import unittest
89
from contextlib import contextmanager
910
from typing import Generator
@@ -15,6 +16,7 @@
1516
create_scheduler,
1617
AWSBatchScheduler,
1718
_role_to_node_properties,
19+
_local_session,
1820
)
1921

2022

@@ -198,6 +200,11 @@ def test_volume_mounts(self) -> None:
198200
mounts=[
199201
specs.VolumeMount(src="efsid", dst_path="/dst", read_only=True),
200202
],
203+
resource=specs.Resource(
204+
cpu=1,
205+
memMB=1000,
206+
gpu=0,
207+
),
201208
)
202209
props = _role_to_node_properties(0, role)
203210
self.assertEqual(
@@ -402,3 +409,16 @@ def test_log_iter(self) -> None:
402409
"foobar",
403410
],
404411
)
412+
413+
def test_local_session(self) -> None:
414+
a: object = _local_session()
415+
self.assertIs(a, _local_session())
416+
417+
def worker() -> None:
418+
b = _local_session()
419+
self.assertIs(b, _local_session())
420+
self.assertIsNot(a, b)
421+
422+
t = threading.Thread(target=worker)
423+
t.start()
424+
t.join()

0 commit comments

Comments
 (0)