Skip to content

schedulers/aws_batch: fix thread local sessions + raise error on missing memory resource #430

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchx/components/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def torchserve(
entrypoint="python",
args=args,
port_map={"model-download": 8222},
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
),
],
)
1 change: 1 addition & 0 deletions torchx/components/test/serve_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_torchserve(self) -> None:
"1",
],
port_map={"model-download": 8222},
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
),
],
)
Expand Down
34 changes: 29 additions & 5 deletions torchx/components/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def echo(
entrypoint="echo",
args=[msg],
num_replicas=num_replicas,
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
)
],
)
Expand All @@ -62,12 +63,21 @@ def touch(file: str, image: str = torchx.IMAGE) -> specs.AppDef:
entrypoint="touch",
args=[file],
num_replicas=1,
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
)
],
)


def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.AppDef:
def sh(
*args: str,
image: str = torchx.IMAGE,
num_replicas: int = 1,
cpu: int = 1,
gpu: int = 0,
memMB: int = 1024,
h: Optional[str] = None,
) -> specs.AppDef:
"""
Runs the provided command via sh. Currently sh does not support
environment variable substitution.
Expand All @@ -76,7 +86,10 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap
args: bash arguments
image: image to use
num_replicas: number of replicas to run

cpu: number of cpus per replica
gpu: number of gpus per replica
memMB: cpu memory in MB per replica
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
"""

escaped_args = " ".join(shlex.quote(arg) for arg in args)
Expand All @@ -90,6 +103,7 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap
entrypoint="sh",
args=["-c", escaped_args],
num_replicas=num_replicas,
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
)
],
)
Expand All @@ -102,7 +116,7 @@ def python(
script: Optional[str] = None,
image: str = torchx.IMAGE,
name: str = "torchx_utils_python",
cpu: int = 2,
cpu: int = 1,
gpu: int = 0,
memMB: int = 1024,
h: Optional[str] = None,
Expand Down Expand Up @@ -164,8 +178,12 @@ def python(
def binary(
*args: str,
entrypoint: str,
name: str = "torchx_utils_python",
name: str = "torchx_utils_binary",
num_replicas: int = 1,
cpu: int = 1,
gpu: int = 0,
memMB: int = 1024,
h: Optional[str] = None,
) -> specs.AppDef:
"""
Test component
Expand All @@ -174,6 +192,10 @@ def binary(
args: arguments passed to the program in sys.argv[1:] (ignored with `--c`)
name: name of the job
num_replicas: number of copies to run (each on its own container)
cpu: number of cpus per replica
gpu: number of gpus per replica
memMB: cpu memory in MB per replica
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
:return:
"""
return specs.AppDef(
Expand All @@ -184,8 +206,8 @@ def binary(
image="<NONE>",
entrypoint=entrypoint,
num_replicas=num_replicas,
resource=specs.Resource(cpu=2, gpu=0, memMB=4096),
args=[*args],
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
)
],
)
Expand Down Expand Up @@ -219,6 +241,7 @@ def copy(src: str, dst: str, image: str = torchx.IMAGE) -> specs.AppDef:
"--dst",
dst,
],
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
),
],
)
Expand Down Expand Up @@ -261,6 +284,7 @@ def booth(
"--tracker_base",
tracker_base,
],
resource=specs.Resource(cpu=1, gpu=0, memMB=1024),
)
],
)
52 changes: 38 additions & 14 deletions torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,17 @@
import threading
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple
from typing import (
Dict,
Iterable,
Mapping,
Optional,
Any,
TYPE_CHECKING,
Tuple,
TypeVar,
Callable,
)

import torchx
import yaml
Expand Down Expand Up @@ -86,8 +96,10 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]:
reqs.append({"type": "VCPU", "value": str(cpu)})

memMB = resource.memMB
if memMB <= 0:
memMB = 1000
if memMB < 0:
raise ValueError(
f"AWSBatchScheduler requires memMB to be set to a positive value, got {memMB}"
)
reqs.append({"type": "MEMORY", "value": str(memMB)})

if resource.gpu > 0:
Expand Down Expand Up @@ -162,17 +174,29 @@ def __repr__(self) -> str:
return str(self)


def _thread_local_session() -> "boto3.session.Session":
KEY = "torchx_aws_batch_session"
local = threading.local()
if hasattr(local, KEY):
# pyre-ignore[16]
return getattr(local, KEY)
T = TypeVar("T")


def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
local: threading.local = threading.local()
key: str = "value"

def wrapper() -> T:
if key in local.__dict__:
return local.__dict__[key]

v = f()
local.__dict__[key] = v
return v

return wrapper


@_thread_local_cache
def _local_session() -> "boto3.session.Session":
    """Return a boto3 session cached per-thread by ``_thread_local_cache``.

    Presumably cached per-thread because boto3 sessions are not safe to
    share across threads — confirm against boto3's thread-safety notes.
    The diff residue referencing undefined ``local``/``KEY`` and the
    unreachable duplicate ``return`` have been removed.
    """
    # Imported lazily so this module can be imported without boto3 installed.
    import boto3.session

    return boto3.session.Session()


class AWSBatchScheduler(Scheduler, DockerWorkspace):
Expand Down Expand Up @@ -244,14 +268,14 @@ def __init__(
def _client(self) -> Any:
if self.__client:
return self.__client
return _thread_local_session().client("batch")
return _local_session().client("batch")

@property
# pyre-fixme[3]: Return annotation cannot be `Any`.
def _log_client(self) -> Any:
if self.__log_client:
return self.__log_client
return _thread_local_session().client("logs")
return _local_session().client("logs")

def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
cfg = dryrun_info._cfg
Expand Down
20 changes: 20 additions & 0 deletions torchx/schedulers/test/aws_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
import unittest
from contextlib import contextmanager
from typing import Generator
Expand All @@ -15,6 +16,7 @@
create_scheduler,
AWSBatchScheduler,
_role_to_node_properties,
_local_session,
)


Expand Down Expand Up @@ -198,6 +200,11 @@ def test_volume_mounts(self) -> None:
mounts=[
specs.VolumeMount(src="efsid", dst_path="/dst", read_only=True),
],
resource=specs.Resource(
cpu=1,
memMB=1000,
gpu=0,
),
)
props = _role_to_node_properties(0, role)
self.assertEqual(
Expand Down Expand Up @@ -402,3 +409,16 @@ def test_log_iter(self) -> None:
"foobar",
],
)

def test_local_session(self) -> None:
    """The cached session is stable within a thread but distinct per thread."""
    main_session: object = _local_session()
    # Repeated calls on the same thread return the identical object.
    self.assertIs(main_session, _local_session())

    def check_other_thread() -> None:
        other = _local_session()
        self.assertIs(other, _local_session())
        # A different thread must get its own, separate session.
        self.assertIsNot(main_session, other)

    thread = threading.Thread(target=check_other_thread)
    thread.start()
    thread.join()