Skip to content

Commit 5cea44d

Browse files
aivanoufacebook-github-bot
authored andcommitted
Make docstring optional (#259)
Summary: Pull Request resolved: #259 * Refactor docstring functions: combines two functions that retrieve docstring into one * Make docstring optional * Remove docstring validator Differential Revision: D31671125 fbshipit-source-id: 2cb867f06742287019f628046eacdfca374e9aa6
1 parent 95ea9f5 commit 5cea44d

File tree

6 files changed

+118
-226
lines changed

6 files changed

+118
-226
lines changed

torchx/specs/api.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Generic,
2020
Iterator,
2121
List,
22-
Mapping,
2322
Optional,
2423
Tuple,
2524
Type,
@@ -28,8 +27,7 @@
2827
)
2928

3029
import yaml
31-
from pyre_extensions import none_throws
32-
from torchx.specs.file_linter import parse_fn_docstring
30+
from torchx.specs.file_linter import get_fn_docstring
3331
from torchx.util.types import decode_from_string, decode_optional, is_bool, is_primitive
3432

3533

@@ -748,22 +746,21 @@ def get_argparse_param_type(parameter: inspect.Parameter) -> Callable[[str], obj
748746
return str
749747

750748

751-
def _create_args_parser(
752-
fn_name: str,
753-
parameters: Mapping[str, inspect.Parameter],
754-
function_desc: str,
755-
args_desc: Dict[str, str],
756-
) -> argparse.ArgumentParser:
749+
def _create_args_parser(app_fn: Callable[..., AppDef]) -> argparse.ArgumentParser:
750+
parameters = inspect.signature(app_fn).parameters
751+
function_desc, args_desc = get_fn_docstring(app_fn)
757752
script_parser = argparse.ArgumentParser(
758-
prog=f"torchx run ...torchx_params... {fn_name} ",
753+
prog=f"torchx run <<torchx_params>> {app_fn.__name__} ",
759754
description=f"App spec: {function_desc}",
755+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
760756
)
761757

762758
remainder_arg = []
763759

764760
for param_name, parameter in parameters.items():
761+
param_desc = args_desc[parameter.name]
765762
args: Dict[str, Any] = {
766-
"help": args_desc[param_name],
763+
"help": param_desc,
767764
"type": get_argparse_param_type(parameter),
768765
}
769766
if parameter.default != inspect.Parameter.empty:
@@ -788,20 +785,15 @@ def _create_args_parser(
788785
def _get_function_args(
789786
app_fn: Callable[..., AppDef], app_args: List[str]
790787
) -> Tuple[List[object], List[str], Dict[str, object]]:
791-
docstring = none_throws(inspect.getdoc(app_fn))
792-
function_desc, args_desc = parse_fn_docstring(docstring)
793-
794-
parameters = inspect.signature(app_fn).parameters
795-
script_parser = _create_args_parser(
796-
app_fn.__name__, parameters, function_desc, args_desc
797-
)
788+
script_parser = _create_args_parser(app_fn)
798789

799790
parsed_args = script_parser.parse_args(app_args)
800791

801792
function_args = []
802793
var_arg = []
803794
kwargs = {}
804795

796+
parameters = inspect.signature(app_fn).parameters
805797
for param_name, parameter in parameters.items():
806798
arg_value = getattr(parsed_args, param_name)
807799
parameter_type = parameter.annotation

torchx/specs/file_linter.py

Lines changed: 28 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import abc
99
import ast
10+
import inspect
1011
from dataclasses import dataclass
11-
from typing import Dict, List, Optional, Tuple, cast
12+
from typing import Dict, List, Optional, Tuple, cast, Callable
1213

1314
from docstring_parser import parse
1415
from pyre_extensions import none_throws
@@ -18,53 +19,40 @@
1819
# pyre-ignore-all-errors[16]
1920

2021

21-
def get_arg_names(app_specs_func_def: ast.FunctionDef) -> List[str]:
22-
arg_names = []
23-
fn_args = app_specs_func_def.args
24-
for arg_def in fn_args.args:
25-
arg_names.append(arg_def.arg)
26-
if fn_args.vararg:
27-
arg_names.append(fn_args.vararg.arg)
28-
for arg in fn_args.kwonlyargs:
29-
arg_names.append(arg.arg)
30-
return arg_names
22+
def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str, str]:
23+
parameters = inspect.signature(fn).parameters
24+
args_decs = {}
25+
for parameter_name in parameters.keys():
26+
args_decs[parameter_name] = parameter_name
27+
return args_decs
3128

3229

33-
def parse_fn_docstring(func_description: str) -> Tuple[str, Dict[str, str]]:
30+
def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
3431
"""
35-
Given a docstring in a google-style format, returns the function description and
36-
description of all arguments.
37-
See: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
32+
Parses the function and arguments description from the provided function. Docstring should be in
33+
gogle-style format: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
34+
35+
If function has no docstring, the function descriptoin will be the name of the function, and
36+
the arguments descriptions will be names of the arguments.
37+
38+
The arguments that are not present in the docstring will contain their names as description
39+
40+
Args:
41+
fn: Function with or without docstring
42+
43+
Returns:
44+
function description, arguments description where key is the name of the argument and value
45+
if the description
3846
"""
39-
args_description = {}
47+
args_description = _get_default_arguments_descriptions(fn)
48+
func_description = inspect.getdoc(fn)
49+
if not func_description:
50+
return fn.__name__, args_description
4051
docstring = parse(func_description)
4152
for param in docstring.params:
4253
args_description[param.arg_name] = param.description
4354
short_func_description = docstring.short_description
44-
return (short_func_description or "", args_description)
45-
46-
47-
def _get_fn_docstring(
48-
source: str, function_name: str
49-
) -> Optional[Tuple[str, Dict[str, str]]]:
50-
module = ast.parse(source)
51-
for expr in module.body:
52-
if type(expr) == ast.FunctionDef:
53-
func_def = cast(ast.FunctionDef, expr)
54-
if func_def.name == function_name:
55-
docstring = ast.get_docstring(func_def)
56-
if not docstring:
57-
return None
58-
return parse_fn_docstring(docstring)
59-
return None
60-
61-
62-
def get_short_fn_description(path: str, function_name: str) -> Optional[str]:
63-
source = read_conf_file(path)
64-
docstring = _get_fn_docstring(source, function_name)
65-
if not docstring:
66-
return None
67-
return docstring[0]
55+
return (short_func_description or fn.__name__, args_description)
6856

6957

7058
@dataclass
@@ -91,38 +79,6 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
9179
)
9280

9381

94-
class TorchxDocstringValidator(TorchxFunctionValidator):
95-
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
96-
"""
97-
Validates the docstring of the `get_app_spec` function. Criteria:
98-
* There mast be google-style docstring
99-
* If there are more than zero arguments, there mast be a `Args:` section defined
100-
with all arguments included.
101-
"""
102-
docsting = ast.get_docstring(app_specs_func_def)
103-
lineno = app_specs_func_def.lineno
104-
if not docsting:
105-
desc = (
106-
f"`{app_specs_func_def.name}` is missing a Google Style docstring, please add one. "
107-
"For more information on the docstring format see: "
108-
"https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html"
109-
)
110-
return [self._gen_linter_message(desc, lineno)]
111-
112-
arg_names = get_arg_names(app_specs_func_def)
113-
_, docstring_arg_defs = parse_fn_docstring(docsting)
114-
missing_args = [
115-
arg_name for arg_name in arg_names if arg_name not in docstring_arg_defs
116-
]
117-
if len(missing_args) > 0:
118-
desc = (
119-
f"`{app_specs_func_def.name}` not all function arguments are present"
120-
f" in the docstring. Missing args: {missing_args}"
121-
)
122-
return [self._gen_linter_message(desc, lineno)]
123-
return []
124-
125-
12682
class TorchxFunctionArgsValidator(TorchxFunctionValidator):
12783
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
12884
linter_errors = []
@@ -149,7 +105,6 @@ def _validate_arg_def(
149105
)
150106
]
151107
if isinstance(arg_def.annotation, ast.Name):
152-
# TODO(aivanou): add support for primitive type check
153108
return []
154109
complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
155110
if complex_type_def.value.id == "Optional":
@@ -239,12 +194,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
239194
Visitor that finds the component_function and runs registered validators on it.
240195
Current registered validators:
241196
242-
* TorchxDocstringValidator - validates the docstring of the function.
243-
Criteria:
244-
* There format should be google-python
245-
* If there are more than zero arguments defined, there
246-
should be obligatory `Args:` section that describes each argument on a new line.
247-
248197
* TorchxFunctionArgsValidator - validates arguments of the function.
249198
Criteria:
250199
* Each argument should be annotated with the type
@@ -260,7 +209,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
260209

261210
def __init__(self, component_function_name: str) -> None:
262211
self.validators = [
263-
TorchxDocstringValidator(),
264212
TorchxFunctionArgsValidator(),
265213
TorchxReturnValidator(),
266214
]

torchx/specs/finder.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pyre_extensions import none_throws
1919
from torchx.specs import AppDef
20-
from torchx.specs.file_linter import get_short_fn_description, validate
20+
from torchx.specs.file_linter import get_fn_docstring, validate
2121
from torchx.util import entrypoints
2222
from torchx.util.io import read_conf_file
2323

@@ -40,14 +40,15 @@ class _Component:
4040
Args:
4141
name: The name of the component, which usually MODULE_PATH.FN_NAME
4242
description: The description of the component, taken from the desrciption
43-
of the function that creates component
43+
of the function that creates component. In case of no docstring, description
44+
will be the same as name
4445
fn_name: Function name that creates component
4546
fn: Function that creates component
4647
validation_errors: Validation errors
4748
"""
4849

4950
name: str
50-
description: Optional[str]
51+
description: str
5152
fn_name: str
5253
fn: Callable[..., AppDef]
5354
validation_errors: List[str]
@@ -119,9 +120,10 @@ def _get_components_from_dir(
119120
search_pattern = os.path.join(search_dir, "**", "*.py")
120121
component_defs = []
121122
for filepath in glob.glob(search_pattern, recursive=True):
122-
module = self._try_load_module(
123-
self._get_module_name(filepath, search_dir, base_module)
124-
)
123+
module_name = self._get_module_name(filepath, search_dir, base_module)
124+
if module_name.startswith("torchx.components.base"):
125+
continue
126+
module = self._try_load_module(module_name)
125127
defs = self._get_components_from_module(base_module, module)
126128
component_defs += defs
127129
return component_defs
@@ -146,7 +148,7 @@ def _get_components_from_module(
146148
module_path = os.path.abspath(module.__file__)
147149
for function_name, function in functions:
148150
linter_errors = validate(module_path, function_name)
149-
component_desc = get_short_fn_description(module_path, function_name)
151+
component_desc, _ = get_fn_docstring(function)
150152
component_def = _Component(
151153
name=self._get_component_name(
152154
base_module, module.__name__, function_name
@@ -193,7 +195,6 @@ def find(self) -> List[_Component]:
193195
validation_errors = self._get_validation_errors(
194196
self._filepath, self._function_name
195197
)
196-
fn_desc = get_short_fn_description(self._filepath, self._function_name)
197198

198199
file_source = read_conf_file(self._filepath)
199200
namespace = globals()
@@ -203,6 +204,7 @@ def find(self) -> List[_Component]:
203204
f"Function {self._function_name} does not exist in file {self._filepath}"
204205
)
205206
app_fn = namespace[self._function_name]
207+
fn_desc, _ = get_fn_docstring(app_fn)
206208
return [
207209
_Component(
208210
name=f"{self._filepath}:{self._function_name}",

torchx/specs/test/api_test.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
import argparse
89
import sys
910
import unittest
1011
from dataclasses import asdict
@@ -33,6 +34,7 @@
3334
make_app_handle,
3435
parse_app_handle,
3536
runopts,
37+
_create_args_parser,
3638
)
3739

3840

@@ -463,11 +465,6 @@ def _test_complex_fn(
463465
app_name: AppDef name
464466
containers: List of containers
465467
roles_scripts: Dict role_name -> role_script
466-
num_cpus: List of cpus per role
467-
num_gpus: Dict role_name -> gpus used for role
468-
nnodes: Num replicas per role
469-
first_arg: First argument to the user script
470-
roles_args: Roles args
471468
"""
472469
num_roles = len(roles_scripts)
473470
if not num_cpus:
@@ -710,3 +707,28 @@ def test_varargs_only_arg_first(self) -> None:
710707
_TEST_VAR_ARGS_FIRST,
711708
(("fooval", "--foo", "barval", "arg1", "arg2"), "asdf"),
712709
)
710+
711+
def _get_argument_help(
712+
self, parser: argparse.ArgumentParser, name: str
713+
) -> Optional[str]:
714+
actions = parser._actions
715+
for action in actions:
716+
if action.dest == name:
717+
return action.help
718+
return None
719+
720+
def test_argparster_complex_fn_partial(self) -> None:
721+
parser = _create_args_parser(_test_complex_fn)
722+
self.assertEqual("AppDef name", self._get_argument_help(parser, "app_name"))
723+
self.assertEqual(
724+
"List of containers", self._get_argument_help(parser, "containers")
725+
)
726+
self.assertEqual(
727+
"Dict role_name -> role_script",
728+
self._get_argument_help(parser, "roles_scripts"),
729+
)
730+
self.assertEqual("num_cpus", self._get_argument_help(parser, "num_cpus"))
731+
self.assertEqual("num_gpus", self._get_argument_help(parser, "num_gpus"))
732+
self.assertEqual("nnodes", self._get_argument_help(parser, "nnodes"))
733+
self.assertEqual("first_arg", self._get_argument_help(parser, "first_arg"))
734+
self.assertEqual("roles_args", self._get_argument_help(parser, "roles_args"))

0 commit comments

Comments
 (0)