diff --git a/scripts/slurmtest.sh b/scripts/slurmtest.sh
index 0f2fe347f..f2f556b3c 100755
--- a/scripts/slurmtest.sh
+++ b/scripts/slurmtest.sh
@@ -18,7 +18,7 @@ source "$VENV"/bin/activate
 python --version
 pip install "$REMOTE_WHEEL"
 
-APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10 utils.echo --num_replicas 3)"
+APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10,comment=hello utils.echo --num_replicas 3)"
 torchx status "$APP_ID"
 torchx describe "$APP_ID"
 sacct -j "$(basename "$APP_ID")"
diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py
index 695bf9794..19ca645b2 100644
--- a/torchx/schedulers/slurm_scheduler.py
+++ b/torchx/schedulers/slurm_scheduler.py
@@ -54,9 +54,15 @@
     "TIMEOUT": AppState.FAILED,
 }
 
-SBATCH_OPTIONS = {
+SBATCH_JOB_OPTIONS = {
+    "comment",
+    "mail-user",
+    "mail-type",
+}
+SBATCH_GROUP_OPTIONS = {
     "partition",
     "time",
+    "constraint",
 }
 
 
@@ -90,7 +96,7 @@ def from_role(
         for k, v in cfg.items():
             if v is None:
                 continue
-            if k in SBATCH_OPTIONS:
+            if k in SBATCH_GROUP_OPTIONS:
                 sbatch_opts[k] = str(v)
         sbatch_opts.setdefault("ntasks-per-node", "1")
         resource = role.resource
@@ -271,6 +277,26 @@ def run_opts(self) -> runopts:
             default=False,
             help="disables memory request to workaround https://github.com/aws/aws-parallelcluster/issues/2198",
         )
+        opts.add(
+            "comment",
+            type_=str,
+            help="Comment to set on the slurm job.",
+        )
+        opts.add(
+            "constraint",
+            type_=str,
+            help="Constraint to use for the slurm job.",
+        )
+        opts.add(
+            "mail-user",
+            type_=str,
+            help="User to mail on job end.",
+        )
+        opts.add(
+            "mail-type",
+            type_=str,
+            help="What events to mail users on.",
+        )
         return opts
 
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
@@ -301,8 +327,14 @@ def _submit_dryrun(
             name = f"{role.name}-{replica_id}"
             replica_role = values.apply(role)
             replicas[name] = SlurmReplicaRequest.from_role(name, replica_role, cfg)
+        cmd = ["sbatch", "--parsable"]
+
+        for k in SBATCH_JOB_OPTIONS:
+            if k in cfg and cfg[k] is not None:
+                cmd += [f"--{k}={cfg[k]}"]
+
         req = SlurmBatchRequest(
-            cmd=["sbatch", "--parsable"],
+            cmd=cmd,
             replicas=replicas,
         )
         return AppDryRunInfo(req, repr)
diff --git a/torchx/schedulers/test/slurm_scheduler_test.py b/torchx/schedulers/test/slurm_scheduler_test.py
index f645db215..c6b5e790a 100644
--- a/torchx/schedulers/test/slurm_scheduler_test.py
+++ b/torchx/schedulers/test/slurm_scheduler_test.py
@@ -34,27 +34,51 @@ def tmp_cwd() -> Generator[None, None, None]:
         os.chdir(cwd)
 
 
+def simple_role() -> specs.Role:
+    return specs.Role(
+        name="foo",
+        image="/some/path",
+        entrypoint="echo",
+        args=["hello slurm", "test"],
+        env={
+            "FOO": "bar",
+        },
+        num_replicas=5,
+        resource=specs.Resource(
+            cpu=2,
+            memMB=10,
+            gpu=3,
+        ),
+    )
+
+
+def simple_app() -> specs.AppDef:
+    return specs.AppDef(
+        name="foo",
+        roles=[
+            specs.Role(
+                name="a",
+                image="/some/path",
+                entrypoint="echo",
+                args=[specs.macros.replica_id, f"hello {specs.macros.app_id}"],
+                num_replicas=2,
+            ),
+            specs.Role(
+                name="b",
+                image="/some/path",
+                entrypoint="echo",
+            ),
+        ],
+    )
+
+
 class SlurmSchedulerTest(unittest.TestCase):
     def test_create_scheduler(self) -> None:
         scheduler = create_scheduler("foo")
         self.assertIsInstance(scheduler, SlurmScheduler)
 
     def test_replica_request(self) -> None:
-        role = specs.Role(
-            name="foo",
-            image="/some/path",
-            entrypoint="echo",
-            args=["hello slurm", "test"],
-            env={
-                "FOO": "bar",
-            },
-            num_replicas=5,
-            resource=specs.Resource(
-                cpu=2,
-                memMB=10,
-                gpu=3,
-            ),
-        )
+        role = simple_role()
         sbatch, srun = SlurmReplicaRequest.from_role(
             "role-0", role, cfg={}
         ).materialize()
@@ -79,9 +103,9 @@ def test_replica_request(self) -> None:
             ],
         )
 
-        # test nomem option
+    def test_replica_request_nomem(self) -> None:
         sbatch, srun = SlurmReplicaRequest.from_role(
-            "role-name", role, cfg={"nomem": True}
+            "role-name", simple_role(), cfg={"nomem": True}
         ).materialize()
         self.assertEqual(
             sbatch,
@@ -93,6 +117,15 @@ def test_replica_request(self) -> None:
             ],
         )
 
+    def test_replica_request_constraint(self) -> None:
+        sbatch, srun = SlurmReplicaRequest.from_role(
+            "role-name", simple_role(), cfg={"constraint": "orange"}
+        ).materialize()
+        self.assertIn(
+            "--constraint=orange",
+            sbatch,
+        )
+
     def test_replica_request_app_id(self) -> None:
         role = specs.Role(
             name="foo",
@@ -132,23 +165,7 @@ def test_replica_request_run_config(self) -> None:
 
     def test_dryrun_multi_role(self) -> None:
         scheduler = create_scheduler("foo")
-        app = specs.AppDef(
-            name="foo",
-            roles=[
-                specs.Role(
-                    name="a",
-                    image="/some/path",
-                    entrypoint="echo",
-                    args=[specs.macros.replica_id, f"hello {specs.macros.app_id}"],
-                    num_replicas=2,
-                ),
-                specs.Role(
-                    name="b",
-                    image="/some/path",
-                    entrypoint="echo",
-                ),
-            ],
-        )
+        app = simple_app()
         info = scheduler.submit_dryrun(app, cfg={})
         req = info.request
         self.assertIsInstance(req, SlurmBatchRequest)
@@ -344,3 +361,36 @@ def test_log_iter(self, run: MagicMock) -> None:
             )
         )
         self.assertEqual(logs, ["hello", "world"])
+
+    def test_dryrun_comment(self) -> None:
+        scheduler = create_scheduler("foo")
+        app = simple_app()
+        info = scheduler.submit_dryrun(
+            app,
+            cfg={
+                "comment": "banana foo bar",
+            },
+        )
+        self.assertIn(
+            "--comment=banana foo bar",
+            info.request.cmd,
+        )
+
+    def test_dryrun_mail(self) -> None:
+        scheduler = create_scheduler("foo")
+        app = simple_app()
+        info = scheduler.submit_dryrun(
+            app,
+            cfg={
+                "mail-user": "foo@bar.com",
+                "mail-type": "END",
+            },
+        )
+        self.assertIn(
+            "--mail-user=foo@bar.com",
+            info.request.cmd,
+        )
+        self.assertIn(
+            "--mail-type=END",
+            info.request.cmd,
+        )
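
Usage note (not part of the patch): a hedged sketch of how the new options can be combined on the torchx CLI, reusing the values exercised by slurmtest.sh and the dryrun tests above (comment=hello, constraint=orange, mail-user=foo@bar.com, mail-type=END); the partition/time values are the placeholders already used in slurmtest.sh:

    # sketch only -- assumes the same test cluster settings as scripts/slurmtest.sh
    torchx run --wait --scheduler slurm \
        --scheduler_args partition=compute,time=10,comment=hello,constraint=orange,mail-user=foo@bar.com,mail-type=END \
        utils.echo --num_replicas 3

Per the diff, comment/mail-user/mail-type land on the sbatch invocation itself (SBATCH_JOB_OPTIONS), while constraint is applied per replica group (SBATCH_GROUP_OPTIONS) alongside partition and time.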