@@ -15,10 +15,13 @@
 import shlex
 import subprocess
 import tempfile
+import warnings
 from dataclasses import dataclass
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from datetime import datetime
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Iterable

-from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, Scheduler
+from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, Scheduler, Stream
+from torchx.schedulers.local_scheduler import LogIterator
 from torchx.specs import (
     NONE,
     AppDef,
@@ -100,26 +103,41 @@ def from_role(
         if resource.gpu > 0:
             sbatch_opts.setdefault("gpus-per-task", str(resource.gpu))

+        srun_opts = {
+            "output": f"slurm-{macros.app_id}-{name}.out",
+        }
+
         return cls(
             name=name,
             entrypoint=role.entrypoint,
             args=list(role.args),
             sbatch_opts=sbatch_opts,
-            srun_opts={},
+            srun_opts=srun_opts,
             env=dict(role.env),
         )

+    def _opts_to_strs(self, opts: Dict[str, Optional[str]]) -> List[str]:
+        out = []
+        for key, value in opts.items():
+            if value is not None:
+                out.append(f"--{key}={value}")
+            else:
+                out.append(f"--{key}")
+        return out
+
     def materialize(self) -> Tuple[List[str], List[str]]:
         """
         materialize returns the sbatch and srun groups for this role. They
         should be combined using `:` per slurm heterogeneous groups.
         """
         sbatch_args = [
             f"--job-name={self.name}",
-        ] + [f"--{key}={value}" for key, value in self.sbatch_opts.items()]
-        srun_args = [f"--{key}={value}" for key, value in self.srun_opts.items()] + [
-            f"--export={key}={value}" for key, value in self.env.items()
-        ]
+        ] + self._opts_to_strs(self.sbatch_opts)
+        srun_args = self._opts_to_strs(self.srun_opts)
+
+        if len(self.env) > 0:
+            kvs = [f"{key}={value}" for key, value in self.env.items()]
+            srun_args += ["--export=ALL," + ",".join(kvs)]

         srun_group = srun_args + [self.entrypoint] + self.args
         srun_group = [_apply_app_id_env(arg) for arg in srun_group]
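For context on how these pieces fit together: `_opts_to_strs` renders an options dict into CLI flags (a `None` value yields a bare flag such as `--requeue`), and each role's `srun` group is later joined with `:` into a single Slurm heterogeneous job. A minimal standalone sketch with hypothetical option values:

```python
from typing import Dict, List, Optional

def opts_to_strs(opts: Dict[str, Optional[str]]) -> List[str]:
    # Mirrors the helper above: None means a valueless flag.
    return [f"--{k}={v}" if v is not None else f"--{k}" for k, v in opts.items()]

# Hypothetical materialized srun groups for two replicas of a "trainer" role:
groups = [
    opts_to_strs({"output": "slurm-1234-trainer-0.out"}) + ["python", "train.py"],
    opts_to_strs({"output": "slurm-1234-trainer-1.out"}) + ["python", "train.py"],
]
# Heterogeneous groups are separated by ":" on the srun command line:
print("srun " + " : ".join(" ".join(g) for g in groups))
```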
@@ -160,6 +178,9 @@ def materialize(self) -> str:
 # exit on error
 set -e

+export PYTHONUNBUFFERED=1
+export SLURM_UNBUFFEREDIO=1
+
 srun {" ".join(srun_groups)}
 """
         sbatch_cmd = self.cmd + sbatch_groups
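These exports are presumably there so replica output reaches the per-replica log files promptly enough to tail: with stdout redirected to a file, Python block-buffers output by default, and `SLURM_UNBUFFEREDIO` is srun's environment equivalent of its `--unbuffered` flag. A small illustration (hypothetical replica code, not part of this diff):

```python
import time

# With stdout redirected to a file and PYTHONUNBUFFERED unset, this line
# can sit in the block buffer long after it is printed:
print("step 0 complete")
time.sleep(60)

# flush=True is the in-code equivalent of what PYTHONUNBUFFERED=1
# enforces process-wide:
print("step 1 complete", flush=True)
```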
@@ -176,7 +197,11 @@ class SlurmScheduler(Scheduler):
     resource allocations and args and then sbatch is used to launch all of them
     together.

-    Logs are written to the default slurm log file.
+    Logs are available in combined form via ``torchx log``, via the
+    programmatic API, and in the job launch directory as
+    ``slurm-<jobid>-<role>-<replica_id>.out``. If TorchX is running in a
+    different directory than the one the job was launched from, the logs
+    cannot be found.

     Some of the config options passed to it are added as SBATCH arguments to each
     replica. See https://slurm.schedmd.com/sbatch.html#SECTION_OPTIONS for info
@@ -203,9 +228,7 @@ class SlurmScheduler(Scheduler):
         type: scheduler
         features:
             cancel: true
-            logs: |
-                Logs are accessible via the default slurm log file but not the
-                programmatic API.
+            logs: true
             distributed: true
             describe: |
                 Partial support. SlurmScheduler will return job and replica
@@ -262,7 +285,7 @@ def _submit_dryrun(
                     app_id=macros.app_id,
                     replica_id=str(replica_id),
                 )
-                name = f"{app.name}-{role.name}-{replica_id}"
+                name = f"{role.name}-{replica_id}"
                 replica_role = values.apply(role)
                 replicas[name] = SlurmReplicaRequest.from_role(name, replica_role, cfg)
         req = SlurmBatchRequest(
@@ -308,19 +331,19 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
             ), f"failed to translate slurm state {state} to torchx state"
             app_state = state_enum

-            name_parts = row["JobName"].split("-")
-            if len(name_parts) < 3:
+            role, _, replica_id = row["JobName"].rpartition("-")
+            if not replica_id or not role:
                 # name should always have at least 3 parts but sometimes sacct
                 # is slow to update
                 continue
-            role = name_parts[-2]
-            replica_id = int(name_parts[-1])
             if role not in roles:
                 roles[role] = Role(name=role, num_replicas=0, image="")
                 roles_statuses[role] = RoleStatus(role, [])
             roles[role].num_replicas += 1
             roles_statuses[role].replicas.append(
-                ReplicaStatus(id=replica_id, role=role, state=app_state, hostname=""),
+                ReplicaStatus(
+                    id=int(replica_id), role=role, state=app_state, hostname=""
+                ),
             )

         return DescribeAppResponse(
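`rpartition` matters here because role names may themselves contain hyphens; under the new `<role>-<replica_id>` naming only the final `-` separates the two. For example:

```python
# The role keeps its internal hyphens; only the last "-" splits:
role, _, replica_id = "data-loader-3".rpartition("-")
assert (role, replica_id) == ("data-loader", "3")

# A name without "-" (e.g. a stale sacct row) yields an empty role,
# which the `if not replica_id or not role` guard skips:
role, _, replica_id = "pending".rpartition("-")
assert role == "" and replica_id == "pending"
```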
@@ -331,6 +354,34 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
             msg=msg,
         )

+    def log_iter(
+        self,
+        app_id: str,
+        role_name: str,
+        k: int = 0,
+        regex: Optional[str] = None,
+        since: Optional[datetime] = None,
+        until: Optional[datetime] = None,
+        should_tail: bool = False,
+        streams: Optional[Stream] = None,
+    ) -> Iterable[str]:
+        if since or until:
+            warnings.warn(
+                "since and/or until times specified for SlurmScheduler.log_iter."
+                " These will be ignored and all log lines will be returned"
+            )
+        if streams is not None and streams != Stream.COMBINED:
+            warnings.warn(
+                "streams specified for SlurmScheduler.log_iter."
+                " These will be ignored and all log lines will be returned"
+            )
+
+        log_file = f"slurm-{app_id}-{role_name}-{k}.out"
+
+        return LogIterator(
+            app_id, regex or ".*", log_file, self, should_tail=should_tail
+        )
+

 def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
     return SlurmScheduler(