Skip to content

Commit b9b74ff

Browse files
sarckkfacebook-github-bot
authored andcommitted
Do not ignore None arg when pipelining (#2352)
Summary: Pull Request resolved: #2352 if `arg` to embedding module is `None`, we would ignore it. However, now we also use `_get_node_args_helper` to generate arg list info for preproc modules, and sometimes `None` is passed in as arg/kwarg. With changes in #2342, we can now handle constants. For backward compatibility, adding an optional flag to indicate to `_get_node_args_helper` that we are handling preproc modules. Reviewed By: xing-liu Differential Revision: D61938346
1 parent 875be19 commit b9b74ff

File tree

1 file changed

+13
-2
lines changed
  • torchrec/distributed/train_pipeline

1 file changed

+13
-2
lines changed

torchrec/distributed/train_pipeline/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -744,14 +744,17 @@ def _get_node_args_helper(
744744
pipelined_preprocs: Set[PipelinedPreproc],
745745
context: TrainPipelineContext,
746746
pipeline_preproc: bool,
747+
# Add `None` constants to arg info only for preproc modules
748+
# Defaults to False for backward compatibility
749+
for_preproc_module: bool = False,
747750
) -> Tuple[List[ArgInfo], int]:
748751
"""
749752
Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
750753
It also counts the number of (args + kwargs) found.
751754
"""
752755
arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))]
753756
for arg, arg_info in zip(arguments, arg_info_list):
754-
if arg is None:
757+
if not for_preproc_module and arg is None:
755758
num_found += 1
756759
continue
757760
while True:
@@ -911,7 +914,12 @@ def _get_node_args_helper(
911914
# is either made of preproc module or non-modifying train batch input
912915
# transformations
913916
preproc_args, num_found_safe_preproc_args = _get_node_args(
914-
model, child_node, pipelined_preprocs, context, pipeline_preproc
917+
model,
918+
child_node,
919+
pipelined_preprocs,
920+
context,
921+
pipeline_preproc,
922+
True,
915923
)
916924
if num_found_safe_preproc_args == total_num_args:
917925
logger.info(
@@ -956,6 +964,7 @@ def _get_node_args(
956964
pipelined_preprocs: Set[PipelinedPreproc],
957965
context: TrainPipelineContext,
958966
pipeline_preproc: bool,
967+
for_preproc_module: bool = False,
959968
) -> Tuple[List[ArgInfo], int]:
960969
num_found = 0
961970

@@ -966,6 +975,7 @@ def _get_node_args(
966975
pipelined_preprocs,
967976
context,
968977
pipeline_preproc,
978+
for_preproc_module,
969979
)
970980
kwargs_arg_info_list, num_found = _get_node_args_helper(
971981
model,
@@ -974,6 +984,7 @@ def _get_node_args(
974984
pipelined_preprocs,
975985
context,
976986
pipeline_preproc,
987+
for_preproc_module,
977988
)
978989

979990
# Replace with proper names for kwargs

0 commit comments

Comments
 (0)