
slurm_scheduler: inherit cwd instead of image + skip mem request via cfg #372

Closed · wants to merge 1 commit

torchx/schedulers/slurm_scheduler.py (36 changes: 26 additions & 10 deletions)
@@ -51,6 +51,11 @@
     "TIMEOUT": AppState.FAILED,
 }
 
+SBATCH_OPTIONS = {
+    "partition",
+    "time",
+}
+
 
 def _apply_app_id_env(s: str) -> str:
     """
@@ -68,7 +73,6 @@ class SlurmReplicaRequest:
     """
 
     name: str
-    dir: str
    entrypoint: str
     args: List[str]
     srun_opts: Dict[str, str]
@@ -79,21 +83,25 @@ class SlurmReplicaRequest:
     def from_role(
         cls, name: str, role: Role, cfg: Mapping[str, CfgVal]
     ) -> "SlurmReplicaRequest":
-        sbatch_opts = {k: str(v) for k, v in cfg.items() if v is not None}
+        sbatch_opts = {}
+        for k, v in cfg.items():
+            if v is None:
+                continue
+            if k in SBATCH_OPTIONS:
+                sbatch_opts[k] = str(v)
         sbatch_opts.setdefault("ntasks-per-node", "1")
         resource = role.resource
 
         if resource != NONE:
             if resource.cpu > 0:
                 sbatch_opts.setdefault("cpus-per-task", str(resource.cpu))
-            if resource.memMB > 0:
+            if not cfg.get("nomem") and resource.memMB > 0:
                 sbatch_opts.setdefault("mem", str(resource.memMB))
             if resource.gpu > 0:
                 sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))
 
         return cls(
             name=name,
-            dir=role.image,
             entrypoint=role.entrypoint,
             args=list(role.args),
             sbatch_opts=sbatch_opts,
@@ -109,11 +117,9 @@ def materialize(self) -> Tuple[List[str], List[str]]:
         sbatch_args = [
             f"--job-name={self.name}",
         ] + [f"--{key}={value}" for key, value in self.sbatch_opts.items()]
-        srun_args = (
-            [f"--chdir={self.dir}"]
-            + [f"--{key}={value}" for key, value in self.srun_opts.items()]
-            + [f"--export={key}={value}" for key, value in self.env.items()]
-        )
+        srun_args = [f"--{key}={value}" for key, value in self.srun_opts.items()] + [
+            f"--export={key}={value}" for key, value in self.env.items()
+        ]
 
         srun_group = srun_args + [self.entrypoint] + self.args
         srun_group = [_apply_app_id_env(arg) for arg in srun_group]
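
The net effect of the from_role and materialize changes: cfg keys are now allow-listed before they become sbatch flags, and srun no longer receives a --chdir. A minimal standalone sketch of the new argument building (plain Python mirroring the diff, not the torchx API; build_sbatch_args is a made-up name):

    from typing import Dict, List, Mapping

    SBATCH_OPTIONS = {"partition", "time"}

    def build_sbatch_args(
        name: str, cfg: Mapping[str, object], cpu: int, mem_mb: int, gpu: int
    ) -> List[str]:
        sbatch_opts: Dict[str, str] = {}
        # Only allow-listed cfg keys become sbatch flags; other keys such as
        # "nomem" are consumed as scheduler config instead of being forwarded.
        for k, v in cfg.items():
            if v is not None and k in SBATCH_OPTIONS:
                sbatch_opts[k] = str(v)
        sbatch_opts.setdefault("ntasks-per-node", "1")
        if cpu > 0:
            sbatch_opts.setdefault("cpus-per-task", str(cpu))
        if not cfg.get("nomem") and mem_mb > 0:
            sbatch_opts.setdefault("mem", str(mem_mb))
        if gpu > 0:
            sbatch_opts.setdefault("gpus-per-task", str(gpu))
        return [f"--job-name={name}"] + [f"--{k}={v}" for k, v in sbatch_opts.items()]

    # With nomem=True the --mem request is dropped entirely:
    print(build_sbatch_args("trainer-0", {"partition": "gpu", "nomem": True}, 2, 1024, 3))
    # ['--job-name=trainer-0', '--partition=gpu', '--ntasks-per-node=1',
    #  '--cpus-per-task=2', '--gpus-per-task=3']
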
@@ -172,10 +178,14 @@ class SlurmScheduler(Scheduler):
 
     Logs are written to the default slurm log file.
 
-    Any scheduler options passed to it are added as SBATCH arguments to each
+    Some of the config options passed to it are added as SBATCH arguments to each
     replica. See https://slurm.schedmd.com/sbatch.html#SECTION_OPTIONS for info
     on the arguments.
 
+    Slurm jobs inherit the currently active ``conda`` or ``virtualenv`` and run
+    in the current working directory. This matches the behavior of the
+    ``local_cwd`` scheduler.
+
     For more info see:
 
     * https://slurm.schedmd.com/sbatch.html
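
Because --chdir is gone, a replica runs wherever sbatch was invoked, under whatever conda or virtualenv is active there. A quick way to see this from a REPL (a sketch; the Role and Resource constructor arguments are an assumption about this revision of torchx.specs):

    from torchx.schedulers.slurm_scheduler import SlurmReplicaRequest
    from torchx.specs import Resource, Role

    # Field names mirror the unit tests below; exact constructor signatures
    # are assumed, not confirmed by this diff.
    role = Role(
        name="trainer",
        image="/some/path",  # no longer materialized as --chdir after this change
        entrypoint="echo",
        args=["hello slurm"],
        resource=Resource(cpu=2, gpu=3, memMB=1024, capabilities={}),
    )
    sbatch, srun = SlurmReplicaRequest.from_role("trainer-0", role, cfg={}).materialize()
    print(srun)  # no --chdir prefix: the replica inherits the submit-time cwd
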
@@ -219,6 +229,12 @@ def run_opts(self) -> runopts:
             default=None,
             help="The maximum time the job is allowed to run for.",
         )
+        opts.add(
+            "nomem",
+            type_=bool,
+            default=False,
+            help="disables memory request to workaround https://github.com/aws/aws-parallelcluster/issues/2198",
+        )
         return opts
 
     def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
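
On the caller's side, the new option travels in the scheduler cfg alongside the allow-listed sbatch keys. A sketch of the cfg this scheduler now understands (values are illustrative):

    # Keys in SBATCH_OPTIONS are forwarded to sbatch verbatim; "nomem" is
    # consumed by from_role and suppresses the --mem request (a workaround
    # for https://github.com/aws/aws-parallelcluster/issues/2198).
    cfg = {
        "partition": "compute",  # sbatch --partition=compute
        "time": "1:00:00",       # sbatch --time=1:00:00
        "nomem": True,           # drop --mem=<memMB> from the sbatch args
    }
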
torchx/schedulers/test/slurm_scheduler_test.py (22 changes: 18 additions & 4 deletions)
@@ -54,7 +54,21 @@ def test_replica_request(self) -> None:
         )
         self.assertEqual(
             srun,
-            ["--chdir=/some/path", "--export=FOO=bar", "echo", "'hello slurm'", "test"],
+            ["--export=FOO=bar", "echo", "'hello slurm'", "test"],
         )
 
+        # test nomem option
+        sbatch, srun = SlurmReplicaRequest.from_role(
+            "role-name", role, cfg={"nomem": True}
+        ).materialize()
+        self.assertEqual(
+            sbatch,
+            [
+                "--job-name=role-name",
+                "--ntasks-per-node=1",
+                "--cpus-per-task=2",
+                "--gpus-per-task=3",
+            ],
+        )
+
     def test_replica_request_app_id(self) -> None:
@@ -135,9 +149,9 @@ def test_dryrun_multi_role(self) -> None:
 # exit on error
 set -e
 
-srun --chdir=/some/path echo 0 'hello '"$SLURM_JOB_ID"'' :\\
-    --chdir=/some/path echo 1 'hello '"$SLURM_JOB_ID"'' :\\
-    --chdir=/some/path echo
+srun echo 0 'hello '"$SLURM_JOB_ID"'' :\\
+    echo 1 'hello '"$SLURM_JOB_ID"'' :\\
+    echo
 """,
         )
 
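
The trailing :\\ separators in the expected script come from Slurm's heterogeneous job syntax: all replicas are components of a single srun invocation, separated by ":". A sketch of the join that produces that line (hypothetical helper, mirroring the shape of materialize's output groups):

    # Each replica contributes [srun_args..., entrypoint, args...]; the groups
    # are joined with Slurm's het-job separator " :" plus a line continuation.
    groups = [
        ["echo", "0", "'hello '\"$SLURM_JOB_ID\"''"],
        ["echo", "1", "'hello '\"$SLURM_JOB_ID\"''"],
        ["echo"],
    ]
    cmd = "srun " + " :\\\n".join(" ".join(g) for g in groups)
    print(cmd)
    # srun echo 0 'hello '"$SLURM_JOB_ID"'' :\
    # echo 1 'hello '"$SLURM_JOB_ID"'' :\
    # echo
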