Skip to content

Commit c37cfd7

Browse files
Kiuk Chungfacebook-github-bot
Kiuk Chung
authored and committed
(torchx/config) support builtin argument defaults from .torchxconfig
Summary: **Summary:** Makes it possible to specify component parameter defaults in `.torchxconfig`. See the changes to the `.torchxconfig` files included in this diff and the *Test Plan* section for example usage and the config specification. **Motivation:** Useful UX for those using builtin components that have required params (because no "global" defaults exist universally, and hence they cannot be specified as defaults in the component function declaration) that are always static for a particular user's/team's use case of the builtin. **Example:** `image` in `dist.ddp` will in most cases be some constant for the team, but no universal default exists (and hence it cannot be specified in the function declaration of `dist.ddp` itself), and it is cumbersome to specify it every time on the command line. **Alternative:** Copy the builtin as a separate component and hardcode (or default in the function declaration) the desired fields, but this requires the user to fork the builtin, which is sub-optimal for those in the "exploration/dev" phase who are currently uninterested in productionizing the component. **Other Notes:** While working on this feature, I noticed a few improvements/cleanups that we need to work on, which I'm tracking as [issue-368](#368). We need to push this code in the interest of time, and I've done as much as I could to NOT change any major APIs until we address the issues properly through issue-368. Reviewed By: aivanou Differential Revision: D33576756 fbshipit-source-id: b65af48a570cc83c366df4eb71a8583a0be6018f
1 parent 782f14d commit c37cfd7

File tree

11 files changed

+495
-97
lines changed

11 files changed

+495
-97
lines changed

torchx/cli/cmd_run.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
import threading
1212
from dataclasses import asdict
1313
from pprint import pformat
14-
from typing import Dict, List, Optional, Type
14+
from typing import Dict, List, Optional, Tuple, Type
1515

1616
import torchx.specs as specs
1717
from pyre_extensions import none_throws
1818
from torchx.cli.cmd_base import SubCommand
1919
from torchx.cli.cmd_log import get_logs
2020
from torchx.runner import Runner, config
21-
from torchx.runner.workspaces import get_workspace_runner, WorkspaceRunner
21+
from torchx.runner.config import load_sections
22+
from torchx.runner.workspaces import WorkspaceRunner, get_workspace_runner
2223
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
2324
from torchx.specs import CfgVal
2425
from torchx.specs.finder import (
@@ -31,6 +32,10 @@
3132
from torchx.util.types import to_dict
3233

3334

35+
MISSING_COMPONENT_ERROR_MSG = (
36+
"missing component name, either provide it from the CLI or in .torchxconfig"
37+
)
38+
3439
logger: logging.Logger = logging.getLogger(__name__)
3540

3641

@@ -61,6 +66,54 @@ def _parse_run_config(arg: str, scheduler_opts: specs.runopts) -> Dict[str, CfgV
6166
return conf
6267

6368

69+
def _parse_component_name_and_args(
    component_name_and_args: List[str],
    subparser: argparse.ArgumentParser,
    dirs: Optional[List[str]] = None,  # for testing only
) -> Tuple[str, List[str]]:
    """
    Given a list of nargs parsed from commandline, parses out the component name
    and component args. If component name is not found in the list, then
    the default component is loaded from the [cli:run] component section in
    .torchxconfig. If no default config is specified in .torchxconfig, then
    this method errors out to the specified subparser.

    This method deals with the following input list:

    1. [$component_name, *$component_args]
        - Example: ["utils.echo", "--msg", "hello"] or ["utils.echo"]
        - Note: component name and args both in list
    2. [*$component_args]
        - Example: ["--msg", "hello"] or []
        - Note: component name loaded from .torchxconfig, args in list
        - Note: assumes list is only args if the first element
                looks like an option (e.g. starts with "-")

    A single leading ``--`` delimiter is dropped before either case is
    applied, so ["--"] is treated exactly like [].

    Args:
        component_name_and_args: raw nargs captured from the command line;
            the list is never mutated.
        subparser: parser whose ``error()`` is called when no component name
            can be determined.
        dirs: passed through to ``config.get_config`` (for testing only).

    Returns:
        A ``(component, component_args)`` tuple.
    """
    component = config.get_config(prefix="cli", name="run", key="component", dirs=dirs)
    component_args: List[str] = []

    # make a copy of the input list to guard against side-effects
    args = list(component_name_and_args)

    # `--` is used to delimit between run's options and nargs which includes component args
    # argparse returns the delimiter as part of the nargs so just ignore it if present
    if args and args[0] == "--":
        args = args[1:]

    # re-check emptiness: the input may have consisted solely of the `--`
    # delimiter, in which case indexing args[0] would raise IndexError
    if args:
        if args[0].startswith("-"):
            component_args = args
        else:  # first element is NOT an option; then it must be a component name
            component = args[0]
            component_args = args[1:]

    if not component:
        subparser.error(MISSING_COMPONENT_ERROR_MSG)

    return component, component_args
115+
116+
64117
class CmdBuiltins(SubCommand):
65118
def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
66119
subparser.add_argument(
@@ -126,7 +179,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
126179
help="Stream logs while waiting for app to finish.",
127180
)
128181
subparser.add_argument(
129-
"conf_args",
182+
"component_name_and_args",
130183
nargs=argparse.REMAINDER,
131184
)
132185

@@ -143,33 +196,29 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
143196
scheduler_opts = run_opts[args.scheduler]
144197
cfg = _parse_run_config(args.scheduler_args, scheduler_opts)
145198
config.apply(scheduler=args.scheduler, cfg=cfg)
199+
146200
config_files = config.find_configs()
147201
workspace = (
148202
"file://" + os.path.dirname(config_files[0]) if config_files else None
149203
)
204+
component, component_args = _parse_component_name_and_args(
205+
args.component_name_and_args,
206+
none_throws(self._subparser),
207+
)
150208

151-
if len(args.conf_args) < 1:
152-
none_throws(self._subparser).error(
153-
"the following arguments are required: conf_file, conf_args"
154-
)
155-
156-
# Python argparse would remove `--` if it was the first argument. This
157-
# does not work well for torchx, since torchx.specs.api uses another argparser to
158-
# parse component arguments.
159-
conf_file, conf_args = args.conf_args[0], args.conf_args[1:]
160209
try:
161210
if args.dryrun:
162211
if isinstance(runner, WorkspaceRunner):
163212
dryrun_info = runner.dryrun_component(
164-
conf_file,
165-
conf_args,
213+
component,
214+
component_args,
166215
args.scheduler,
167216
workspace=workspace,
168217
cfg=cfg,
169218
)
170219
else:
171220
dryrun_info = runner.dryrun_component(
172-
conf_file, conf_args, args.scheduler, cfg=cfg
221+
component, component_args, args.scheduler, cfg=cfg
173222
)
174223
logger.info(
175224
"\n=== APPLICATION ===\n"
@@ -180,16 +229,16 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
180229
else:
181230
if isinstance(runner, WorkspaceRunner):
182231
app_handle = runner.run_component(
183-
conf_file,
184-
conf_args,
232+
component,
233+
component_args,
185234
args.scheduler,
186235
workspace=workspace,
187236
cfg=cfg,
188237
)
189238
else:
190239
app_handle = runner.run_component(
191-
conf_file,
192-
conf_args,
240+
component,
241+
component_args,
193242
args.scheduler,
194243
cfg=cfg,
195244
)
@@ -208,7 +257,7 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
208257
self._wait_and_exit(runner, app_handle, log=args.log)
209258

210259
except (ComponentValidationException, ComponentNotFoundException) as e:
211-
error_msg = f"\nFailed to run component `{conf_file}` got errors: \n {e}"
260+
error_msg = f"\nFailed to run component `{component}` got errors: \n {e}"
212261
logger.error(error_msg)
213262
sys.exit(1)
214263
except specs.InvalidRunConfigException as e:
@@ -223,7 +272,8 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
223272

224273
def run(self, args: argparse.Namespace) -> None:
225274
os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
226-
with get_workspace_runner() as runner:
275+
component_defaults = load_sections(prefix="component")
276+
with get_workspace_runner(component_defaults=component_defaults) as runner:
227277
self._run(runner, args)
228278

229279
def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None:

torchx/cli/test/cmd_run_test.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
from typing import Generator, List
1818
from unittest.mock import MagicMock, patch
1919

20-
from torchx.cli.cmd_run import CmdBuiltins, CmdRun, _parse_run_config, logger
20+
from torchx.cli.cmd_run import (
21+
CmdBuiltins,
22+
CmdRun,
23+
_parse_component_name_and_args,
24+
_parse_run_config,
25+
logger,
26+
)
2127
from torchx.schedulers.local_scheduler import SignalException
2228
from torchx.specs import runopts
2329

@@ -198,6 +204,63 @@ def test_parse_runopts(self) -> None:
198204
for k, v in expected_args.items():
199205
self.assertEqual(v, runconfig.get(k))
200206

207+
def test_parse_component_name_and_args_no_default(self) -> None:
208+
sp = argparse.ArgumentParser(prog="test")
209+
self.assertEqual(
210+
("utils.echo", []),
211+
_parse_component_name_and_args(["utils.echo"], sp),
212+
)
213+
self.assertEqual(
214+
("utils.echo", []),
215+
_parse_component_name_and_args(["--", "utils.echo"], sp),
216+
)
217+
self.assertEqual(
218+
("utils.echo", ["--msg", "hello"]),
219+
_parse_component_name_and_args(["utils.echo", "--msg", "hello"], sp),
220+
)
221+
222+
with self.assertRaises(SystemExit):
223+
_parse_component_name_and_args(["--msg", "hello"], sp)
224+
225+
with self.assertRaises(SystemExit):
226+
_parse_component_name_and_args(["-m", "hello"], sp)
227+
228+
def test_parse_component_name_and_args_with_default(self) -> None:
229+
sp = argparse.ArgumentParser(prog="test")
230+
dirs = [str(self.tmpdir)]
231+
232+
with open(Path(self.tmpdir) / ".torchxconfig", "w") as f:
233+
f.write(
234+
"""#
235+
[cli:run]
236+
component = custom.echo
237+
"""
238+
)
239+
240+
self.assertEqual(
241+
("utils.echo", []), _parse_component_name_and_args(["utils.echo"], sp, dirs)
242+
)
243+
self.assertEqual(
244+
("utils.echo", ["--msg", "hello"]),
245+
_parse_component_name_and_args(["utils.echo", "--msg", "hello"], sp, dirs),
246+
)
247+
self.assertEqual(
248+
("custom.echo", []),
249+
_parse_component_name_and_args([], sp, dirs),
250+
)
251+
self.assertEqual(
252+
("custom.echo", ["--msg", "hello"]),
253+
_parse_component_name_and_args(["--", "--msg", "hello"], sp, dirs),
254+
)
255+
self.assertEqual(
256+
("custom.echo", ["--msg", "hello"]),
257+
_parse_component_name_and_args(["--msg", "hello"], sp, dirs),
258+
)
259+
self.assertEqual(
260+
("custom.echo", ["-m", "hello"]),
261+
_parse_component_name_and_args(["-m", "hello"], sp, dirs),
262+
)
263+
201264

202265
class CmdBuiltinTest(unittest.TestCase):
203266
def test_run(self) -> None:

torchx/examples/apps/.torchxconfig

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,20 @@ enableGracefulPreemption = False
2020
secure_group = pytorch_r2p
2121
entitlement = default
2222
proxy_workflow_image = None
23+
24+
[cli:run]
25+
component = fb.dist.hpc
26+
27+
# TODO need to add hydra to bento_kernel_torchx and make that the default img
28+
[component:fb.dist.ddp]
29+
img = bento_kernel_pytorch_lightning
30+
m = fb/compute_world_size/main.py
31+
32+
[component:fb.dist.ddp2]
33+
img = bento_kernel_pytorch_lightning
34+
m = fb/compute_world_size/main.py
35+
36+
[component:fb.dist.hpc]
37+
img = bento_kernel_pytorch_lightning
38+
m = fb/compute_world_size/main.py
39+

torchx/runner/api.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
self,
5151
name: str,
5252
schedulers: Dict[SchedulerBackend, Scheduler],
53+
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
5354
) -> None:
5455
"""
5556
Creates a new runner instance.
@@ -63,6 +64,9 @@ def __init__(
6364
self._schedulers = schedulers
6465
self._apps: Dict[AppHandle, AppDef] = {}
6566

67+
# component_name -> map of component_fn_param_name -> user-specified default val encoded as str
68+
self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {}
69+
6670
def __enter__(self) -> "Runner":
6771
return self
6872

@@ -147,7 +151,11 @@ def dryrun_component(
147151
component, but just returns what "would" have run.
148152
"""
149153
component_def = get_component(component)
150-
app = from_function(component_def.fn, component_args)
154+
app = from_function(
155+
component_def.fn,
156+
component_args,
157+
self._component_defaults.get(component, None),
158+
)
151159
return self.dryrun(app, scheduler, cfg)
152160

153161
def run(
@@ -521,7 +529,11 @@ def __repr__(self) -> str:
521529
return f"Runner(name={self._name}, schedulers={self._schedulers}, apps={self._apps})"
522530

523531

524-
def get_runner(name: Optional[str] = None, **scheduler_params: Any) -> Runner:
532+
def get_runner(
533+
name: Optional[str] = None,
534+
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
535+
**scheduler_params: Any,
536+
) -> Runner:
525537
"""
526538
Convenience method to construct and get a Runner object. Usage:
527539
@@ -554,4 +566,4 @@ def get_runner(name: Optional[str] = None, **scheduler_params: Any) -> Runner:
554566
name = "torchx"
555567

556568
schedulers = get_schedulers(session_name=name, **scheduler_params)
557-
return Runner(name, schedulers)
569+
return Runner(name, schedulers, component_defaults)

0 commit comments

Comments
 (0)