Skip to content

Commit ca1973d

Browse files
Eugen Hotaj authored and facebook-github-bot committed
Create serializable TorchXRunner.
Summary: Forks TorchX's Ax TorchXRunner into one which can be serialized across runtimes. Also updates TorchXScheduler and TorchXRunner to expose a `stop` method for trials.
Reviewed By: Balandat
Differential Revision: D34214986
fbshipit-source-id: 0719212b244a4c8e3efa5adf27f9fb1df1dee397
1 parent 1137162 commit ca1973d

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

torchx/runtime/hpo/ax.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import inspect
9-
from typing import Any, Callable, Dict, Mapping, Optional, Set, cast
9+
from typing import Iterable, Any, Callable, Dict, Mapping, Optional, Set, cast
1010

1111
import pandas as pd
1212
from ax.core import Trial
@@ -206,6 +206,28 @@ def run(self, trial: BaseTrial) -> Dict[str, Any]:
206206
_TORCHX_TRACKER_BASE: self._tracker_base,
207207
}
208208

209+
def poll_trial_status(
    self, trials: Iterable[BaseTrial]
) -> Dict[TrialStatus, Set[int]]:
    """Group the given trials' indices by their current trial status.

    For each trial, the TorchX app handle is read from the trial's
    ``run_metadata``, the app's status is queried through this runner's
    ``_torchx_runner``, and the app state is translated to a trial status
    via ``APP_STATE_TO_TRIAL_STATUS``.

    Args:
        trials: The trials whose statuses should be looked up.

    Returns:
        A mapping from trial status to the set of indices of the trials
        currently in that status. Statuses that no trial is in are absent
        from the mapping.

    Raises:
        AssertionError: If the TorchX runner returns ``None`` as the
            status of a trial's app.
    """
    indices_by_status: Dict[TrialStatus, Set[int]] = {}

    for trial in trials:
        handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
        status: Optional[AppStatus] = self._torchx_runner.status(handle)
        # The runner's status call is typed Optional; a missing status is
        # treated as a hard error here rather than mapped to a trial status.
        assert status is not None
        mapped_status = APP_STATE_TO_TRIAL_STATUS[status.state]
        indices_by_status.setdefault(mapped_status, set()).add(trial.index)

    return indices_by_status
224+
225+
def stop(self, trial: BaseTrial, reason: Optional[str] = None) -> Dict[str, Any]:
    """Stop the TorchX app backing the given trial.

    Args:
        trial: The trial to stop; its app handle is read from
            ``trial.run_metadata``.
        reason: Optional explanation of why the trial is being stopped.

    Returns:
        ``{"reason": reason}`` when a non-empty reason was supplied,
        otherwise an empty dict.
    """
    handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
    self._torchx_runner.stop(handle)
    # A missing or empty reason produces no metadata at all.
    if not reason:
        return {}
    return {"reason": reason}
230+
209231

210232
class TorchXScheduler(ax_Scheduler):
211233
"""
@@ -219,22 +241,6 @@ class TorchXScheduler(ax_Scheduler):
219241
220242
"""
221243

222-
def poll_trial_status(
223-
self, poll_all_trial_statuses: bool = False
224-
) -> Dict[TrialStatus, Set[int]]:
225-
trial_statuses: Dict[TrialStatus, Set[int]] = {}
226-
227-
for trial in self.running_trials:
228-
app_handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
229-
torchx_runner = trial.run_metadata[_TORCHX_RUNNER]
230-
app_status: AppStatus = torchx_runner.status(app_handle)
231-
trial_status = APP_STATE_TO_TRIAL_STATUS[app_status.state]
232-
233-
indices = trial_statuses.setdefault(trial_status, set())
234-
indices.add(trial.index)
235-
236-
return trial_statuses
237-
238244
def poll_available_capacity(self) -> int:
239245
"""
240246
Used when ``run_trials_in_batches`` option is set.
@@ -249,7 +255,6 @@ def poll_available_capacity(self) -> int:
249255
and scheduling policies.
250256
251257
"""
252-
253258
return (
254259
-1
255260
if self.generation_strategy._curr.max_parallelism is None

0 commit comments

Comments
 (0)