6
6
# LICENSE file in the root directory of this source tree.
7
7
8
8
import inspect
9
- from typing import Any , Callable , Dict , Mapping , Optional , Set , cast
9
+ from typing import Iterable , Any , Callable , Dict , Mapping , Optional , Set , cast
10
10
11
11
import pandas as pd
12
12
from ax .core import Trial
@@ -206,6 +206,28 @@ def run(self, trial: BaseTrial) -> Dict[str, Any]:
206
206
_TORCHX_TRACKER_BASE : self ._tracker_base ,
207
207
}
208
208
209
def poll_trial_status(
    self, trials: Iterable[BaseTrial]
) -> Dict[TrialStatus, Set[int]]:
    """Return the statuses of the given trials, grouped by status.

    Args:
        trials: Trials to poll. Each trial's ``run_metadata`` is indexed
            with ``_TORCHX_APP_HANDLE`` to obtain its app handle.

    Returns:
        A mapping from each observed ``TrialStatus`` to the set of indices
        of the trials in that status. Statuses that no polled trial is in
        are absent from the mapping.

    Raises:
        RuntimeError: If the runner returns no status for a trial's app
            handle.
    """
    trial_statuses: Dict[TrialStatus, Set[int]] = {}

    for trial in trials:
        app_handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
        app_status: Optional[AppStatus] = self._torchx_runner.status(app_handle)
        # Raise explicitly instead of using ``assert``: asserts are stripped
        # under ``python -O``, which would otherwise surface here as an
        # opaque AttributeError on ``None.state``.
        if app_status is None:
            raise RuntimeError(
                f"No status available for app handle {app_handle!r} "
                f"(trial index {trial.index})"
            )
        trial_status = APP_STATE_TO_TRIAL_STATUS[app_status.state]
        trial_statuses.setdefault(trial_status, set()).add(trial.index)

    return trial_statuses
224
+
225
def stop(self, trial: BaseTrial, reason: Optional[str] = None) -> Dict[str, Any]:
    """Stop the app backing ``trial`` through the TorchX runner.

    Args:
        trial: Trial whose ``run_metadata`` holds the app handle under
            ``_TORCHX_APP_HANDLE``.
        reason: Optional explanation for stopping the trial.

    Returns:
        ``{"reason": reason}`` when a non-empty reason is given, otherwise
        an empty dict.
    """
    handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
    self._torchx_runner.stop(handle)

    if not reason:
        return {}
    return {"reason": reason}
230
+
209
231
210
232
class TorchXScheduler (ax_Scheduler ):
211
233
"""
@@ -219,22 +241,6 @@ class TorchXScheduler(ax_Scheduler):
219
241
220
242
"""
221
243
222
def poll_trial_status(
    self, poll_all_trial_statuses: bool = False
) -> Dict[TrialStatus, Set[int]]:
    """Group the indices of ``self.running_trials`` by their current status.

    Each running trial's app handle and runner are read from its
    ``run_metadata``; the runner is asked for the app's status, and the
    app state is translated through ``APP_STATE_TO_TRIAL_STATUS``.

    Args:
        poll_all_trial_statuses: Accepted but not used by this method.

    Returns:
        A mapping from each observed ``TrialStatus`` to the set of indices
        of the running trials in that status.
    """
    indices_by_status: Dict[TrialStatus, Set[int]] = {}

    for running in self.running_trials:
        handle: str = running.run_metadata[_TORCHX_APP_HANDLE]
        state = running.run_metadata[_TORCHX_RUNNER].status(handle).state
        bucket = indices_by_status.setdefault(
            APP_STATE_TO_TRIAL_STATUS[state], set()
        )
        bucket.add(running.index)

    return indices_by_status
237
-
238
244
def poll_available_capacity (self ) -> int :
239
245
"""
240
246
Used when ``run_trials_in_batches`` option is set.
@@ -249,7 +255,6 @@ def poll_available_capacity(self) -> int:
249
255
and scheduling policies.
250
256
251
257
"""
252
-
253
258
return (
254
259
- 1
255
260
if self .generation_strategy ._curr .max_parallelism is None
0 commit comments