|
38 | 38 | import threading
|
39 | 39 | from dataclasses import dataclass
|
40 | 40 | from datetime import datetime
|
41 |
| -from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple |
| 41 | +from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple, TypeVar, Callable |
42 | 42 |
|
43 | 43 | import torchx
|
44 | 44 | import yaml
|
@@ -86,8 +86,8 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
|
86 | 86 | reqs.append({"type": "VCPU", "value": str(cpu)})
|
87 | 87 |
|
88 | 88 | mem = resource.memMB
|
89 |
| - if mem <= 0: |
90 |
| - mem = 1000 |
| 89 | + if mem < 0: |
| 90 | + raise ValueError(f"AWSBatchScheduler requires memMB to be set to a positive value, got {mem}") |
91 | 91 | reqs.append({"type": "MEMORY", "value": str(mem)})
|
92 | 92 |
|
93 | 93 | if resource.gpu > 0:
|
@@ -157,17 +157,27 @@ def __repr__(self) -> str:
|
157 | 157 | return str(self)
|
158 | 158 |
|
159 | 159 |
|
160 |
| -def _thread_local_session() -> "boto3.session.Session": |
161 |
| - KEY = "torchx_aws_batch_session" |
162 |
| - local = threading.local() |
163 |
| - if hasattr(local, KEY): |
164 |
| - # pyre-ignore[16] |
165 |
| - return getattr(local, KEY) |
| 160 | +T = TypeVar('T') |
| 161 | + |
| 162 | +def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]: |
| 163 | + local: threading.local = threading.local() |
| 164 | + key: str = "value" |
| 165 | + def wrapper() -> T: |
| 166 | + if key in local.__dict__: |
| 167 | + return local.__dict__[key] |
| 168 | + |
| 169 | + v = f() |
| 170 | + local.__dict__[key] = v |
| 171 | + return v |
| 172 | + |
| 173 | + return wrapper |
| 174 | + |
| 175 | + |
| 176 | +@_thread_local_cache |
| 177 | +def _local_session() -> "boto3.session.Session": |
166 | 178 | import boto3.session
|
167 | 179 |
|
168 |
| - session = boto3.session.Session() |
169 |
| - setattr(local, KEY, session) |
170 |
| - return session |
| 180 | + return boto3.session.Session() |
171 | 181 |
|
172 | 182 |
|
173 | 183 | class AWSBatchScheduler(Scheduler, DockerWorkspace):
|
@@ -239,14 +249,14 @@ def __init__(
|
239 | 249 | def _client(self) -> Any:
|
240 | 250 | if self.__client:
|
241 | 251 | return self.__client
|
242 |
| - return _thread_local_session().client("batch") |
| 252 | + return _local_session().client("batch") |
243 | 253 |
|
244 | 254 | @property
|
245 | 255 | # pyre-fixme[3]: Return annotation cannot be `Any`.
|
246 | 256 | def _log_client(self) -> Any:
|
247 | 257 | if self.__log_client:
|
248 | 258 | return self.__log_client
|
249 |
| - return _thread_local_session().client("logs") |
| 259 | + return _local_session().client("logs") |
250 | 260 |
|
251 | 261 | def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
|
252 | 262 | cfg = dryrun_info._cfg
|
|
0 commit comments