diff --git a/torchx/components/serve.py b/torchx/components/serve.py index 3bb5f63b8..e04cfe1df 100644 --- a/torchx/components/serve.py +++ b/torchx/components/serve.py @@ -70,6 +70,7 @@ def torchserve( entrypoint="python", args=args, port_map={"model-download": 8222}, + resource=specs.Resource(cpu=1, gpu=0, memMB=1024), ), ], ) diff --git a/torchx/components/test/serve_test.py b/torchx/components/test/serve_test.py index fcbddd121..6200c5b32 100644 --- a/torchx/components/test/serve_test.py +++ b/torchx/components/test/serve_test.py @@ -31,6 +31,7 @@ def test_torchserve(self) -> None: "1", ], port_map={"model-download": 8222}, + resource=specs.Resource(cpu=1, gpu=0, memMB=1024), ), ], ) diff --git a/torchx/components/utils.py b/torchx/components/utils.py index e3728c44e..cfe95192b 100644 --- a/torchx/components/utils.py +++ b/torchx/components/utils.py @@ -39,6 +39,7 @@ def echo( entrypoint="echo", args=[msg], num_replicas=num_replicas, + resource=specs.Resource(cpu=1, gpu=0, memMB=1024), ) ], ) @@ -62,12 +63,21 @@ def touch(file: str, image: str = torchx.IMAGE) -> specs.AppDef: entrypoint="touch", args=[file], num_replicas=1, + resource=specs.Resource(cpu=1, gpu=0, memMB=1024), ) ], ) -def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.AppDef: +def sh( + *args: str, + image: str = torchx.IMAGE, + num_replicas: int = 1, + cpu: int = 1, + gpu: int = 0, + memMB: int = 1024, + h: Optional[str] = None, +) -> specs.AppDef: """ Runs the provided command via sh. Currently sh does not support environment variable substitution. @@ -76,7 +86,10 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap args: bash arguments image: image to use num_replicas: number of replicas to run - + cpu: number of cpus per replica + gpu: number of gpus per replica + memMB: cpu memory in MB per replica + h: a registered named resource (if specified takes precedence over cpu, gpu, memMB) """ escaped_args = " ".join(shlex.quote(arg) for arg in args) @@ -90,6 +103,7 @@ def sh(*args: str, image: str = torchx.IMAGE, num_replicas: int = 1) -> specs.Ap entrypoint="sh", args=["-c", escaped_args], num_replicas=num_replicas, + resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h), ) ], ) @@ -102,7 +116,7 @@ def python( script: Optional[str] = None, image: str = torchx.IMAGE, name: str = "torchx_utils_python", - cpu: int = 2, + cpu: int = 1, gpu: int = 0, memMB: int = 1024, h: Optional[str] = None, @@ -164,8 +178,12 @@ def python( def binary( *args: str, entrypoint: str, - name: str = "torchx_utils_python", + name: str = "torchx_utils_binary", num_replicas: int = 1, + cpu: int = 1, + gpu: int = 0, + memMB: int = 1024, + h: Optional[str] = None, ) -> specs.AppDef: """ Test component @@ -174,6 +192,10 @@ def binary( args: arguments passed to the program in sys.argv[1:] (ignored with `--c`) name: name of the job num_replicas: number of copies to run (each on its own container) + cpu: number of cpus per replica + gpu: number of gpus per replica + memMB: cpu memory in MB per replica + h: a registered named resource (if specified takes precedence over cpu, gpu, memMB) :return: """ return specs.AppDef( @@ -184,8 +206,8 @@ def binary( image="", entrypoint=entrypoint, num_replicas=num_replicas, - resource=specs.Resource(cpu=2, gpu=0, memMB=4096), args=[*args], + resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h), ) ], ) @@ -219,6 +241,7 @@ def copy(src: str, dst: str, image: str = torchx.IMAGE) -> specs.AppDef: "--dst", dst, ], + resource=specs.Resource(cpu=1, gpu=0, memMB=1024), ), ], ) @@ -261,6 +284,7 @@ def booth( "--tracker_base", tracker_base, ], + resource=specs.Resource(cpu=1, gpu=0, memMB=1024), ) ], ) diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index 2777023b5..a24af4be1 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -38,7 +38,17 @@ import threading from dataclasses import dataclass from datetime import datetime -from typing import Dict, Iterable, Mapping, Optional, Any, TYPE_CHECKING, Tuple +from typing import ( + Dict, + Iterable, + Mapping, + Optional, + Any, + TYPE_CHECKING, + Tuple, + TypeVar, + Callable, +) import torchx import yaml @@ -86,8 +96,10 @@ def _role_to_node_properties(idx: int, role: Role) -> Dict[str, object]: reqs.append({"type": "VCPU", "value": str(cpu)}) memMB = resource.memMB - if memMB <= 0: - memMB = 1000 + if memMB < 0: + raise ValueError( + f"AWSBatchScheduler requires memMB to be set to a positive value, got {memMB}" + ) reqs.append({"type": "MEMORY", "value": str(memMB)}) if resource.gpu > 0: @@ -162,17 +174,29 @@ def __repr__(self) -> str: return str(self) -def _thread_local_session() -> "boto3.session.Session": - KEY = "torchx_aws_batch_session" - local = threading.local() - if hasattr(local, KEY): - # pyre-ignore[16] - return getattr(local, KEY) +T = TypeVar("T") + + +def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]: + local: threading.local = threading.local() + key: str = "value" + + def wrapper() -> T: + if key in local.__dict__: + return local.__dict__[key] + + v = f() + local.__dict__[key] = v + return v + + return wrapper + + +@_thread_local_cache +def _local_session() -> "boto3.session.Session": import boto3.session - session = boto3.session.Session() - setattr(local, KEY, session) - return session + return boto3.session.Session() class AWSBatchScheduler(Scheduler, DockerWorkspace): @@ -244,14 +268,14 @@ def __init__( def _client(self) -> Any: if self.__client: return self.__client - return _thread_local_session().client("batch") + return _local_session().client("batch") @property # pyre-fixme[3]: Return annotation cannot be `Any`. def _log_client(self) -> Any: if self.__log_client: return self.__log_client - return _thread_local_session().client("logs") + return _local_session().client("logs") def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str: cfg = dryrun_info._cfg diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index b54f90cf8..50736b533 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import threading import unittest from contextlib import contextmanager from typing import Generator @@ -15,6 +16,7 @@ create_scheduler, AWSBatchScheduler, _role_to_node_properties, + _local_session, ) @@ -198,6 +200,11 @@ def test_volume_mounts(self) -> None: mounts=[ specs.VolumeMount(src="efsid", dst_path="/dst", read_only=True), ], + resource=specs.Resource( + cpu=1, + memMB=1000, + gpu=0, + ), ) props = _role_to_node_properties(0, role) self.assertEqual( @@ -402,3 +409,16 @@ def test_log_iter(self) -> None: "foobar", ], ) + + def test_local_session(self) -> None: + a: object = _local_session() + self.assertIs(a, _local_session()) + + def worker() -> None: + b = _local_session() + self.assertIs(b, _local_session()) + self.assertIsNot(a, b) + + t = threading.Thread(target=worker) + t.start() + t.join()