
Commit 663f78a

sarckk authored and facebook-github-bot committed
Do not ignore None arg when pipelining
Summary: If an `arg` to an embedding module is `None`, we previously ignored it. However, `_get_node_args_helper` is now also used to generate arg-list info for preproc modules, where `None` is sometimes passed in as an arg/kwarg. With the changes in pytorch#2342, we can now handle constants, so add an optional flag that tells `_get_node_args_helper` it is handling a preproc module; the flag defaults to False for backward compatibility.

Reviewed By: xing-liu

Differential Revision: D61938346
1 parent 255d254 commit 663f78a

File tree (1 file changed: +13 −2 lines)
  • torchrec/distributed/train_pipeline


torchrec/distributed/train_pipeline/utils.py

Lines changed: 13 additions & 2 deletions
@@ -744,14 +744,17 @@ def _get_node_args_helper(
     pipelined_preprocs: Set[PipelinedPreproc],
     context: TrainPipelineContext,
     pipeline_preproc: bool,
+    # Add `None` constants to arg info only for preproc modules
+    # Defaults to False for backward compatibility
+    for_preproc_module: bool = False,
 ) -> Tuple[List[ArgInfo], int]:
     """
     Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s.
     It also counts the number of (args + kwargs) found.
     """
     arg_info_list = [ArgInfo([], [], [], [], None) for _ in range(len(arguments))]
     for arg, arg_info in zip(arguments, arg_info_list):
-        if arg is None:
+        if not for_preproc_module and arg is None:
             num_found += 1
             continue
         while True:
@@ -912,7 +915,12 @@ def _get_node_args_helper(
                 # is either made of preproc module or non-modifying train batch input
                 # transformations
                 preproc_args, num_found_safe_preproc_args = _get_node_args(
-                    model, child_node, pipelined_preprocs, context, pipeline_preproc
+                    model,
+                    child_node,
+                    pipelined_preprocs,
+                    context,
+                    pipeline_preproc,
+                    True,
                 )
                 if num_found_safe_preproc_args == total_num_args:
                     logger.info(
@@ -957,6 +965,7 @@ def _get_node_args(
     pipelined_preprocs: Set[PipelinedPreproc],
     context: TrainPipelineContext,
     pipeline_preproc: bool,
+    for_preproc_module: bool = False,
 ) -> Tuple[List[ArgInfo], int]:
     num_found = 0
 
@@ -967,6 +976,7 @@ def _get_node_args(
         pipelined_preprocs,
         context,
         pipeline_preproc,
+        for_preproc_module,
     )
     kwargs_arg_info_list, num_found = _get_node_args_helper(
         model,
@@ -975,6 +985,7 @@ def _get_node_args(
         pipelined_preprocs,
         context,
         pipeline_preproc,
+        for_preproc_module,
     )
 
     # Replace with proper names for kwargs
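
Conceptually, the patch toggles whether a `None` argument is skipped or kept as a constant. The sketch below illustrates only that toggle; it is a minimal stand-alone example, not the torchrec implementation: `collect_arg_info` is a hypothetical stand-in for `_get_node_args_helper`, with plain Python values in place of fx node arguments and `ArgInfo` entries.

# Minimal sketch of the behavior controlled by `for_preproc_module`.
# `collect_arg_info` is a hypothetical stand-in for `_get_node_args_helper`.
from typing import Any, List, Optional, Sequence, Tuple


def collect_arg_info(
    arguments: Sequence[Optional[Any]],
    for_preproc_module: bool = False,
) -> Tuple[List[Optional[Any]], int]:
    arg_info_list: List[Optional[Any]] = []
    num_found = 0
    for arg in arguments:
        if not for_preproc_module and arg is None:
            # Original behavior: count the arg but record nothing for it.
            num_found += 1
            continue
        # For preproc modules, `None` falls through and is recorded as a
        # constant, so the argument list stays complete.
        arg_info_list.append(arg)
        num_found += 1
    return arg_info_list, num_found


# For an embedding module, `None` is still skipped (only counted) ...
assert collect_arg_info(["features", None]) == (["features"], 2)
# ... while for a preproc module the `None` constant is preserved.
assert collect_arg_info(["features", None], for_preproc_module=True) == (
    ["features", None],
    2,
)

With `for_preproc_module=True` (the `True` passed at the preproc call site in the second hunk above), a preproc module receives a complete argument list even when some of its arguments are `None`.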
